diff --git a/fairseq/fairseq.egg-info/PKG-INFO b/fairseq/fairseq.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..c9bde99ea7cb63bdd1331b4765eb2077c1dafeeb --- /dev/null +++ b/fairseq/fairseq.egg-info/PKG-INFO @@ -0,0 +1,283 @@ +Metadata-Version: 2.2 +Name: fairseq +Version: 0.12.2 +Summary: Facebook AI Research Sequence-to-Sequence Toolkit +Home-page: https://github.com/pytorch/fairseq +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: cffi +Requires-Dist: cython +Requires-Dist: hydra-core<1.1,>=1.0.7 +Requires-Dist: omegaconf<2.1 +Requires-Dist: numpy>=1.21.3 +Requires-Dist: regex +Requires-Dist: sacrebleu>=1.4.12 +Requires-Dist: torch>=1.13 +Requires-Dist: tqdm +Requires-Dist: bitarray +Requires-Dist: torchaudio>=0.8.0 +Requires-Dist: scikit-learn +Requires-Dist: packaging +Provides-Extra: dev +Requires-Dist: flake8; extra == "dev" +Requires-Dist: pytest; extra == "dev" +Requires-Dist: black==22.3.0; extra == "dev" +Provides-Extra: docs +Requires-Dist: sphinx; extra == "docs" +Requires-Dist: sphinx-argparse; extra == "docs" +Dynamic: classifier +Dynamic: description +Dynamic: description-content-type +Dynamic: home-page +Dynamic: provides-extra +Dynamic: requires-dist +Dynamic: summary + +

+ +
+
+ Support Ukraine + MIT License + Latest Release + Build Status + Documentation Status + CircleCI Status +

+ +-------------------------------------------------------------------------------- + +Fairseq(-py) is a sequence modeling toolkit that allows researchers and +developers to train custom models for translation, summarization, language +modeling and other text generation tasks. + +We provide reference implementations of various sequence modeling papers: + +
**List of implemented papers**

+ +* **Convolutional Neural Networks (CNN)** + + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) + + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) + + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) + + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) + + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* **LightConv and DynamicConv models** + + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* **Long Short-Term Memory (LSTM) networks** + + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) +* **Transformer (self-attention) networks** + + Attention Is All You Need (Vaswani et al., 2017) + + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) + + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) + + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) + + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) + + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) + + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) + + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) + + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md) + + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) + + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) + + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) + + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) + + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) + + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979) + + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430) + + [Robust wav2vec 2.0: Analyzing Domain Shift in 
Self-Supervised Pre-Training (Hsu et al., 2021)](https://arxiv.org/abs/2104.01027) + + [Unsupervised Speech Recognition (Baevski et al., 2021)](https://arxiv.org/abs/2105.11084) + + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680) + + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et al., 2021)](https://arxiv.org/pdf/2109.14084.pdf) + + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf) + + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md) +* **Non-autoregressive Transformers** + + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018) + + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019) + + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) + + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* **Finetuning** + + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al., 2020)](examples/rxf/README.md) +

+ +### What's New: +* May 2023 [Released models for Scaling Speech Technology to 1,000+ Languages (Pratap et al., 2023)](examples/mms/README.md) +* June 2022 [Released code for wav2vec-U 2.0 from Towards End-to-end Unsupervised Speech Recognition (Liu et al., 2022)](examples/wav2vec/unsupervised/README.md) +* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers) +* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md) +* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md) +* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md) +* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming). +* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md) +* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md) +* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md) +* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md) +* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) +* February 2021 [Added LASER training code](examples/laser/README.md) +* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md) +* December 2020: [GottBERT model and code released](examples/gottbert/README.md) +* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework + * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) +* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) +* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) +* October 2020: [Added CRISS models and code](examples/criss/README.md) +
**Previous updates**

+ +* September 2020: [Added Linformer code](examples/linformer/README.md) +* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) +* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) +* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) +* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) +* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) +* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) +* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) +* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) +* February 2020: [mBART model and code released](examples/mbart/README.md) +* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german) +* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) +* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) +* November 2019: [CamemBERT model and code released](examples/camembert/README.md) +* November 2019: [BART model and code released](examples/bart/README.md) +* November 2019: [XLM-R models and code released](examples/xlmr/README.md) +* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) +* August 2019: [WMT'19 models released](examples/wmt19/README.md) +* July 2019: fairseq relicensed under MIT license +* July 2019: [RoBERTa models and code released](examples/roberta/README.md) +* June 2019: [wav2vec models and code released](examples/wav2vec/README.md) + +

+ +### Features: + +* multi-GPU training on one machine or across multiple machines (data and model parallel) +* fast generation on both CPU and GPU with multiple search algorithms implemented: + + beam search + + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) + + sampling (unconstrained, top-k and top-p/nucleus) + + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018) +* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU +* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) +* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers +* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file-based configuration +* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md) +* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md) + +We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) +with a convenient `torch.hub` interface: + +``` python +en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') +en2de.translate('Hello world', beam=5) +# 'Hallo Welt' +``` + +See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) +and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. + +# Requirements and Installation + +* [PyTorch](http://pytorch.org/) version >= 1.10.0 +* Python version >= 3.8 +* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) +* **To install fairseq** and develop locally: + +``` bash +git clone https://github.com/pytorch/fairseq +cd fairseq +pip install --editable ./ + +# on macOS: +# CFLAGS="-stdlib=libc++" pip install --editable ./ + +# to install the latest stable release (0.10.x) +# pip install fairseq +``` + +* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: + +``` bash +git clone https://github.com/NVIDIA/apex +cd apex +pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ + --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ + --global-option="--fast_multihead_attn" ./ +``` + +* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` +* If you use Docker, make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` + as command-line options to `nvidia-docker run`. + +# Getting Started + +The [full documentation](https://fairseq.readthedocs.io/) contains instructions +for getting started, training new models and extending fairseq with new model +types and tasks. + +# Pre-trained models and examples + +We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, +as well as example training and evaluation commands. 
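+ +As a quick check of an installation, any of these pre-trained models can be pulled through the `torch.hub` interface shown earlier. The snippet below is a minimal sketch along the lines of the RoBERTa hub tutorial linked above; the model name and the `encode`/`extract_features` calls follow fairseq's published hub interface, and downloading the checkpoint requires network access.
+
+``` python
+import torch
+
+# Download and load the pre-trained RoBERTa model (large configuration)
+roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
+roberta.eval()  # disable dropout for deterministic results
+
+# Apply RoBERTa's BPE to a sentence and extract final-layer features
+tokens = roberta.encode('Hello world!')
+features = roberta.extract_features(tokens)  # shape: (1, num_tokens, 1024)
+```
+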
+ +* [Translation](examples/translation/README.md): convolutional and transformer models are available +* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available + +We also have more detailed READMEs to reproduce results from specific papers: + +* [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md) +* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) +* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) +* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) +* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) +* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) +* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md) +* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) +* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) +* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) +* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) +* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) +* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) +* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) +* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) +* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) +* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) +* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) + +# Join the fairseq community + +* Twitter: https://twitter.com/fairseq +* Facebook page: https://www.facebook.com/groups/fairseq.users +* Google group: https://groups.google.com/forum/#!forum/fairseq-users + +# License + +fairseq(-py) is MIT-licensed. +The license applies to the pre-trained models as well. 
+ +# Citation + +Please cite as: + +``` bibtex +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` diff --git a/fairseq/fairseq.egg-info/SOURCES.txt b/fairseq/fairseq.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..99a0c8b27e21147bb21b33b7c9d3754e0b1e3d68 --- /dev/null +++ b/fairseq/fairseq.egg-info/SOURCES.txt @@ -0,0 +1,1546 @@ +LICENSE +MANIFEST.in +README.md +pyproject.toml +setup.cfg +setup.py +examples/operators/alignment_train_cpu.cpp +examples/operators/alignment_train_cuda.cpp +examples/operators/alignment_train_kernel.cu +fairseq/__init__.py +fairseq/binarizer.py +fairseq/checkpoint_utils.py +fairseq/file_chunker_utils.py +fairseq/file_io.py +fairseq/file_utils.py +fairseq/hub_utils.py +fairseq/incremental_decoding_utils.py +fairseq/iterative_refinement_generator.py +fairseq/nan_detector.py +fairseq/ngram_repeat_block.py +fairseq/options.py +fairseq/pdb.py +fairseq/quantization_utils.py +fairseq/registry.py +fairseq/search.py +fairseq/sequence_generator.py +fairseq/sequence_scorer.py +fairseq/speech_generator.py +fairseq/token_generation_constraints.py +fairseq/tokenizer.py +fairseq/trainer.py +fairseq/utils.py +fairseq/version.py +fairseq/version.txt +fairseq.egg-info/PKG-INFO +fairseq.egg-info/SOURCES.txt +fairseq.egg-info/dependency_links.txt +fairseq.egg-info/entry_points.txt +fairseq.egg-info/not-zip-safe +fairseq.egg-info/requires.txt +fairseq.egg-info/top_level.txt +fairseq/benchmark/__init__.py +fairseq/benchmark/benchmark_multihead_attention.py +fairseq/benchmark/dummy_dataset.py +fairseq/benchmark/dummy_lm.py +fairseq/benchmark/dummy_masked_lm.py +fairseq/benchmark/dummy_model.py +fairseq/benchmark/dummy_mt.py +fairseq/clib/cuda/ngram_repeat_block_cuda.cpp +fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu +fairseq/clib/libbase/balanced_assignment.cpp +fairseq/clib/libbleu/libbleu.cpp +fairseq/clib/libbleu/module.cpp +fairseq/clib/libnat/edit_dist.cpp +fairseq/clib/libnat_cuda/binding.cpp +fairseq/clib/libnat_cuda/edit_dist.cu +fairseq/config/__init__.py +fairseq/config/config.yaml +fairseq/config/fb_run_config/slurm.yaml +fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml +fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml +fairseq/config/model/transformer_lm/transformer_lm_big.yaml +fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml +fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml +fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml +fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml +fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml +fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml +fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml +fairseq/config/model/wav2vec2/wav2vec2_base.yaml +fairseq/config/model/wav2vec2/wav2vec2_large.yaml +fairseq/criterions/__init__.py +fairseq/criterions/adaptive_loss.py +fairseq/criterions/composite_loss.py +fairseq/criterions/cross_entropy.py +fairseq/criterions/ctc.py +fairseq/criterions/fairseq_criterion.py +fairseq/criterions/fastspeech2_loss.py +fairseq/criterions/hubert_criterion.py +fairseq/criterions/label_smoothed_cross_entropy.py 
+fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py +fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py +fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py +fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py +fairseq/criterions/legacy_masked_lm.py +fairseq/criterions/masked_lm.py +fairseq/criterions/model_criterion.py +fairseq/criterions/nat_loss.py +fairseq/criterions/sentence_prediction.py +fairseq/criterions/sentence_prediction_adapters.py +fairseq/criterions/sentence_ranking.py +fairseq/criterions/speech_dlm_criterion.py +fairseq/criterions/speech_to_speech_criterion.py +fairseq/criterions/speech_ulm_criterion.py +fairseq/criterions/tacotron2_loss.py +fairseq/criterions/wav2vec_criterion.py +fairseq/data/__init__.py +fairseq/data/add_class_target_dataset.py +fairseq/data/add_target_dataset.py +fairseq/data/append_token_dataset.py +fairseq/data/backtranslation_dataset.py +fairseq/data/base_wrapper_dataset.py +fairseq/data/bucket_pad_length_dataset.py +fairseq/data/codedataset.py +fairseq/data/colorize_dataset.py +fairseq/data/concat_dataset.py +fairseq/data/concat_sentences_dataset.py +fairseq/data/data_utils.py +fairseq/data/data_utils_fast.pyx +fairseq/data/denoising_dataset.py +fairseq/data/dictionary.py +fairseq/data/fairseq_dataset.py +fairseq/data/fasta_dataset.py +fairseq/data/id_dataset.py +fairseq/data/indexed_dataset.py +fairseq/data/iterators.py +fairseq/data/language_pair_dataset.py +fairseq/data/list_dataset.py +fairseq/data/lm_context_window_dataset.py +fairseq/data/lru_cache_dataset.py +fairseq/data/mask_tokens_dataset.py +fairseq/data/monolingual_dataset.py +fairseq/data/multi_corpus_dataset.py +fairseq/data/multi_corpus_sampled_dataset.py +fairseq/data/nested_dictionary_dataset.py +fairseq/data/noising.py +fairseq/data/num_samples_dataset.py +fairseq/data/numel_dataset.py +fairseq/data/offset_tokens_dataset.py +fairseq/data/pad_dataset.py +fairseq/data/padding_mask_dataset.py +fairseq/data/plasma_utils.py +fairseq/data/prepend_dataset.py +fairseq/data/prepend_token_dataset.py +fairseq/data/raw_label_dataset.py +fairseq/data/replace_dataset.py +fairseq/data/resampling_dataset.py +fairseq/data/roll_dataset.py +fairseq/data/round_robin_zip_datasets.py +fairseq/data/shorten_dataset.py +fairseq/data/sort_dataset.py +fairseq/data/span_mask_tokens_dataset.py +fairseq/data/speech_dlm_dataset.py +fairseq/data/strip_token_dataset.py +fairseq/data/subsample_dataset.py +fairseq/data/text_compressor.py +fairseq/data/token_block_dataset.py +fairseq/data/token_block_utils_fast.pyx +fairseq/data/transform_eos_concat_langpair_dataset.py +fairseq/data/transform_eos_dataset.py +fairseq/data/transform_eos_lang_pair_dataset.py +fairseq/data/audio/__init__.py +fairseq/data/audio/audio_utils.py +fairseq/data/audio/data_cfg.py +fairseq/data/audio/frm_text_to_speech_dataset.py +fairseq/data/audio/hubert_dataset.py +fairseq/data/audio/multi_modality_dataset.py +fairseq/data/audio/raw_audio_dataset.py +fairseq/data/audio/speech_to_speech_dataset.py +fairseq/data/audio/speech_to_text_dataset.py +fairseq/data/audio/speech_to_text_joint_dataset.py +fairseq/data/audio/text_to_speech_dataset.py +fairseq/data/audio/dataset_transforms/__init__.py +fairseq/data/audio/dataset_transforms/concataugment.py +fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py +fairseq/data/audio/feature_transforms/__init__.py +fairseq/data/audio/feature_transforms/delta_deltas.py +fairseq/data/audio/feature_transforms/global_cmvn.py 
+fairseq/data/audio/feature_transforms/specaugment.py +fairseq/data/audio/feature_transforms/utterance_cmvn.py +fairseq/data/audio/waveform_transforms/__init__.py +fairseq/data/audio/waveform_transforms/noiseaugment.py +fairseq/data/encoders/__init__.py +fairseq/data/encoders/byte_bpe.py +fairseq/data/encoders/byte_utils.py +fairseq/data/encoders/bytes.py +fairseq/data/encoders/characters.py +fairseq/data/encoders/fastbpe.py +fairseq/data/encoders/gpt2_bpe.py +fairseq/data/encoders/gpt2_bpe_utils.py +fairseq/data/encoders/hf_bert_bpe.py +fairseq/data/encoders/hf_byte_bpe.py +fairseq/data/encoders/moses_tokenizer.py +fairseq/data/encoders/nltk_tokenizer.py +fairseq/data/encoders/sentencepiece_bpe.py +fairseq/data/encoders/space_tokenizer.py +fairseq/data/encoders/subword_nmt_bpe.py +fairseq/data/encoders/utils.py +fairseq/data/huffman/__init__.py +fairseq/data/huffman/huffman_coder.py +fairseq/data/huffman/huffman_mmap_indexed_dataset.py +fairseq/data/legacy/__init__.py +fairseq/data/legacy/block_pair_dataset.py +fairseq/data/legacy/masked_lm_dataset.py +fairseq/data/legacy/masked_lm_dictionary.py +fairseq/data/multilingual/__init__.py +fairseq/data/multilingual/multilingual_data_manager.py +fairseq/data/multilingual/multilingual_utils.py +fairseq/data/multilingual/sampled_multi_dataset.py +fairseq/data/multilingual/sampled_multi_epoch_dataset.py +fairseq/data/multilingual/sampling_method.py +fairseq/dataclass/__init__.py +fairseq/dataclass/configs.py +fairseq/dataclass/constants.py +fairseq/dataclass/initialize.py +fairseq/dataclass/utils.py +fairseq/distributed/__init__.py +fairseq/distributed/distributed_timeout_wrapper.py +fairseq/distributed/fully_sharded_data_parallel.py +fairseq/distributed/legacy_distributed_data_parallel.py +fairseq/distributed/module_proxy_wrapper.py +fairseq/distributed/tpu_distributed_data_parallel.py +fairseq/distributed/utils.py +fairseq/examples/.gitignore +fairseq/examples/__init__.py +fairseq/examples/MMPT/.gitignore +fairseq/examples/MMPT/CONFIG.md +fairseq/examples/MMPT/DATASET.md +fairseq/examples/MMPT/README.md +fairseq/examples/MMPT/endtask.md +fairseq/examples/MMPT/locallaunch.py +fairseq/examples/MMPT/pretraining.md +fairseq/examples/MMPT/setup.py +fairseq/examples/MMPT/videoclip.png +fairseq/examples/MMPT/vlm.png +fairseq/examples/MMPT/mmpt/__init__.py +fairseq/examples/MMPT/mmpt/datasets/__init__.py +fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py +fairseq/examples/MMPT/mmpt/datasets/mmdataset.py +fairseq/examples/MMPT/mmpt/evaluators/__init__.py +fairseq/examples/MMPT/mmpt/evaluators/evaluator.py +fairseq/examples/MMPT/mmpt/evaluators/metric.py +fairseq/examples/MMPT/mmpt/evaluators/predictor.py +fairseq/examples/MMPT/mmpt/losses/__init__.py +fairseq/examples/MMPT/mmpt/losses/fairseqmmloss.py +fairseq/examples/MMPT/mmpt/losses/loss.py +fairseq/examples/MMPT/mmpt/losses/nce.py +fairseq/examples/MMPT/mmpt/models/__init__.py +fairseq/examples/MMPT/mmpt/models/fairseqmmmodel.py +fairseq/examples/MMPT/mmpt/models/mmfusion.py +fairseq/examples/MMPT/mmpt/models/mmfusionnlg.py +fairseq/examples/MMPT/mmpt/models/transformermodel.py +fairseq/examples/MMPT/mmpt/modules/__init__.py +fairseq/examples/MMPT/mmpt/modules/mm.py +fairseq/examples/MMPT/mmpt/modules/retri.py +fairseq/examples/MMPT/mmpt/modules/vectorpool.py +fairseq/examples/MMPT/mmpt/processors/__init__.py +fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py +fairseq/examples/MMPT/mmpt/processors/dsprocessor.py +fairseq/examples/MMPT/mmpt/processors/how2processor.py 
+fairseq/examples/MMPT/mmpt/processors/how2retriprocessor.py +fairseq/examples/MMPT/mmpt/processors/processor.py +fairseq/examples/MMPT/mmpt/processors/models/s3dg.py +fairseq/examples/MMPT/mmpt/tasks/__init__.py +fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py +fairseq/examples/MMPT/mmpt/tasks/milncetask.py +fairseq/examples/MMPT/mmpt/tasks/retritask.py +fairseq/examples/MMPT/mmpt/tasks/task.py +fairseq/examples/MMPT/mmpt/tasks/vlmtask.py +fairseq/examples/MMPT/mmpt/utils/__init__.py +fairseq/examples/MMPT/mmpt/utils/load_config.py +fairseq/examples/MMPT/mmpt/utils/shardedtensor.py +fairseq/examples/MMPT/mmpt_cli/localjob.py +fairseq/examples/MMPT/mmpt_cli/predict.py +fairseq/examples/MMPT/projects/mfmmlm.yaml +fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml +fairseq/examples/MMPT/projects/mtm/vlm.yaml +fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml +fairseq/examples/MMPT/projects/mtm/vlm/crosstask.yaml +fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml +fairseq/examples/MMPT/projects/mtm/vlm/test_coin.yaml +fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml +fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml +fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml +fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml +fairseq/examples/MMPT/projects/mtm/vlm/test_youcook.yaml +fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml +fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml +fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml +fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml +fairseq/examples/MMPT/projects/mtm/vlm/youcookcap.yaml +fairseq/examples/MMPT/projects/retri/videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoretri.yaml +fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml +fairseq/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml +fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml +fairseq/examples/MMPT/projects/task/coin.yaml +fairseq/examples/MMPT/projects/task/coin_videoclip.yaml +fairseq/examples/MMPT/projects/task/crosstask.yaml +fairseq/examples/MMPT/projects/task/crosstask_videoclip.yaml +fairseq/examples/MMPT/projects/task/default.yaml +fairseq/examples/MMPT/projects/task/ft.yaml +fairseq/examples/MMPT/projects/task/how2.yaml +fairseq/examples/MMPT/projects/task/test.yaml +fairseq/examples/MMPT/projects/task/test_coin.yaml +fairseq/examples/MMPT/projects/task/test_coin_videoclip.yaml +fairseq/examples/MMPT/projects/task/test_coin_zs.yaml +fairseq/examples/MMPT/projects/task/test_crosstask.yaml 
+fairseq/examples/MMPT/projects/task/test_crosstask_videoclip.yaml +fairseq/examples/MMPT/projects/task/test_crosstask_zs.yaml +fairseq/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml +fairseq/examples/MMPT/projects/task/test_didemo_zs.yaml +fairseq/examples/MMPT/projects/task/test_vtt.yaml +fairseq/examples/MMPT/projects/task/test_vtt_videoclip.yaml +fairseq/examples/MMPT/projects/task/test_vtt_zs.yaml +fairseq/examples/MMPT/projects/task/test_vttqa.yaml +fairseq/examples/MMPT/projects/task/test_vttqa_videoclip.yaml +fairseq/examples/MMPT/projects/task/test_vttqa_zs.yaml +fairseq/examples/MMPT/projects/task/test_youcook.yaml +fairseq/examples/MMPT/projects/task/test_youcook_videoclip.yaml +fairseq/examples/MMPT/projects/task/test_youcook_zs.yaml +fairseq/examples/MMPT/projects/task/test_youcookcap.yaml +fairseq/examples/MMPT/projects/task/vtt.yaml +fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml +fairseq/examples/MMPT/projects/task/vttqa.yaml +fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml +fairseq/examples/MMPT/projects/task/youcook.yaml +fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml +fairseq/examples/MMPT/projects/task/youcookcap.yaml +fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py +fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml +fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py +fairseq/examples/MMPT/scripts/video_feature_extractor/model.py +fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py +fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py +fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py +fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py +fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py +fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh +fairseq/examples/adaptive_span/README.md +fairseq/examples/adaptive_span/__init__.py +fairseq/examples/adaptive_span/adagrad_with_grad_clip.py +fairseq/examples/adaptive_span/adaptive_span_attention.py +fairseq/examples/adaptive_span/adaptive_span_loss.py +fairseq/examples/adaptive_span/adaptive_span_model.py +fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py +fairseq/examples/adaptive_span/truncated_bptt_lm_task.py +fairseq/examples/attention_head_selection/README.md +fairseq/examples/attention_head_selection/src/__init__.py +fairseq/examples/attention_head_selection/src/speech_to_text_head_selection.py +fairseq/examples/attention_head_selection/src/data/__init__.py +fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py +fairseq/examples/attention_head_selection/src/loss/__init__.py +fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py +fairseq/examples/attention_head_selection/src/models/__init__.py +fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py +fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py +fairseq/examples/attention_head_selection/src/modules/__init__.py +fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py +fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py +fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py +fairseq/examples/attention_head_selection/src/modules/multihead_functional.py +fairseq/examples/audio_nlp/nlu/README.md 
+fairseq/examples/audio_nlp/nlu/create_dict_stop.sh +fairseq/examples/audio_nlp/nlu/generate_manifests.py +fairseq/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml +fairseq/examples/backtranslation/README.md +fairseq/examples/backtranslation/deduplicate_lines.py +fairseq/examples/backtranslation/extract_bt_data.py +fairseq/examples/backtranslation/prepare-de-monolingual.sh +fairseq/examples/backtranslation/prepare-wmt18en2de.sh +fairseq/examples/backtranslation/sacrebleu.sh +fairseq/examples/backtranslation/tokenized_bleu.sh +fairseq/examples/bart/README.glue.md +fairseq/examples/bart/README.md +fairseq/examples/bart/README.summarization.md +fairseq/examples/bart/summarize.py +fairseq/examples/byte_level_bpe/README.md +fairseq/examples/byte_level_bpe/get_bitext.py +fairseq/examples/byte_level_bpe/get_data.sh +fairseq/examples/byte_level_bpe/gru_transformer.py +fairseq/examples/camembert/README.md +fairseq/examples/constrained_decoding/README.md +fairseq/examples/constrained_decoding/normalize.py +fairseq/examples/constrained_decoding/tok.py +fairseq/examples/conv_seq2seq/README.md +fairseq/examples/criss/README.md +fairseq/examples/criss/download_and_preprocess_flores_test.sh +fairseq/examples/criss/download_and_preprocess_tatoeba.sh +fairseq/examples/criss/save_encoder.py +fairseq/examples/criss/mining/mine.py +fairseq/examples/criss/mining/mine_example.sh +fairseq/examples/criss/sentence_retrieval/encoder_analysis.py +fairseq/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh +fairseq/examples/criss/unsupervised_mt/eval.sh +fairseq/examples/cross_lingual_language_model/README.md +fairseq/examples/data2vec/README.md +fairseq/examples/data2vec/__init__.py +fairseq/examples/data2vec/fb_convert_beit_cp.py +fairseq/examples/data2vec/config/audio/classification/base_classification.yaml +fairseq/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml +fairseq/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml +fairseq/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml +fairseq/examples/data2vec/config/audio/pretraining/audioset.yaml +fairseq/examples/data2vec/config/audio/pretraining/base_librispeech.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/local.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml +fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml +fairseq/examples/data2vec/config/text/pretraining/base.yaml +fairseq/examples/data2vec/config/text/pretraining/run_config/local.yaml +fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml +fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml +fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml +fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml +fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml 
+fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml +fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml +fairseq/examples/data2vec/config/v2/base_audio_only_task.yaml +fairseq/examples/data2vec/config/v2/base_images_only_task.yaml +fairseq/examples/data2vec/config/v2/base_text_only_task.yaml +fairseq/examples/data2vec/config/v2/huge_images14_only_task.yaml +fairseq/examples/data2vec/config/v2/huge_images_only_task.yaml +fairseq/examples/data2vec/config/v2/large_audio_only_task.yaml +fairseq/examples/data2vec/config/v2/large_images_only_task.yaml +fairseq/examples/data2vec/config/v2/large_text_only_task.yaml +fairseq/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml +fairseq/examples/data2vec/config/v2/run_config/local.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_1.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_2.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_3.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_4.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_8.yaml +fairseq/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/cola.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/mnli.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/mrpc.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/qnli.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/qqp.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/rte.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/sst_2.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/sts_b.yaml +fairseq/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml +fairseq/examples/data2vec/config/vision/finetuning/imagenet.yaml +fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml +fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml +fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/local.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml +fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml +fairseq/examples/data2vec/config/vision/pretraining/base_imagenet.yaml +fairseq/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml +fairseq/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/local.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml 
+fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml +fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml +fairseq/examples/data2vec/data/__init__.py +fairseq/examples/data2vec/data/add_class_target_dataset.py +fairseq/examples/data2vec/data/image_dataset.py +fairseq/examples/data2vec/data/mae_finetuning_image_dataset.py +fairseq/examples/data2vec/data/mae_image_dataset.py +fairseq/examples/data2vec/data/modality.py +fairseq/examples/data2vec/data/path_dataset.py +fairseq/examples/data2vec/models/__init__.py +fairseq/examples/data2vec/models/audio_classification.py +fairseq/examples/data2vec/models/data2vec2.py +fairseq/examples/data2vec/models/data2vec_audio.py +fairseq/examples/data2vec/models/data2vec_image_classification.py +fairseq/examples/data2vec/models/data2vec_text.py +fairseq/examples/data2vec/models/data2vec_text_classification.py +fairseq/examples/data2vec/models/data2vec_vision.py +fairseq/examples/data2vec/models/mae.py +fairseq/examples/data2vec/models/mae_image_classification.py +fairseq/examples/data2vec/models/utils.py +fairseq/examples/data2vec/models/modalities/__init__.py +fairseq/examples/data2vec/models/modalities/audio.py +fairseq/examples/data2vec/models/modalities/base.py +fairseq/examples/data2vec/models/modalities/images.py +fairseq/examples/data2vec/models/modalities/modules.py +fairseq/examples/data2vec/models/modalities/text.py +fairseq/examples/data2vec/scripts/convert_audioset_labels.py +fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh +fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh +fairseq/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh +fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh +fairseq/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh +fairseq/examples/data2vec/scripts/text/glue.py +fairseq/examples/data2vec/scripts/text/glue_lr.py +fairseq/examples/data2vec/scripts/text/unprocess_data.py +fairseq/examples/data2vec/scripts/text/valids.py +fairseq/examples/data2vec/tasks/__init__.py +fairseq/examples/data2vec/tasks/audio_classification.py 
+fairseq/examples/data2vec/tasks/image_classification.py +fairseq/examples/data2vec/tasks/image_pretraining.py +fairseq/examples/data2vec/tasks/mae_image_classification.py +fairseq/examples/data2vec/tasks/mae_image_pretraining.py +fairseq/examples/data2vec/tasks/multimodal.py +fairseq/examples/discriminative_reranking_nmt/README.md +fairseq/examples/discriminative_reranking_nmt/__init__.py +fairseq/examples/discriminative_reranking_nmt/drnmt_rerank.py +fairseq/examples/discriminative_reranking_nmt/config/deen.yaml +fairseq/examples/discriminative_reranking_nmt/criterions/__init__.py +fairseq/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py +fairseq/examples/discriminative_reranking_nmt/models/__init__.py +fairseq/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py +fairseq/examples/discriminative_reranking_nmt/scripts/prep_data.py +fairseq/examples/discriminative_reranking_nmt/tasks/__init__.py +fairseq/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py +fairseq/examples/emotion_conversion/README.md +fairseq/examples/emotion_conversion/requirements.txt +fairseq/examples/emotion_conversion/synthesize.py +fairseq/examples/emotion_conversion/emotion_models/__init__.py +fairseq/examples/emotion_conversion/emotion_models/duration_predictor.py +fairseq/examples/emotion_conversion/emotion_models/duration_predictor.yaml +fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.py +fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.yaml +fairseq/examples/emotion_conversion/emotion_models/utils.py +fairseq/examples/emotion_conversion/fairseq_models/__init__.py +fairseq/examples/emotion_conversion/preprocess/__init__.py +fairseq/examples/emotion_conversion/preprocess/build_hifigan_manifest.py +fairseq/examples/emotion_conversion/preprocess/build_translation_manifests.py +fairseq/examples/emotion_conversion/preprocess/create_core_manifest.py +fairseq/examples/emotion_conversion/preprocess/extract_f0.py +fairseq/examples/emotion_conversion/preprocess/process_km.py +fairseq/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py +fairseq/examples/emotion_conversion/preprocess/split_km.py +fairseq/examples/emotion_conversion/preprocess/split_km_tsv.py +fairseq/examples/fast_noisy_channel/README.md +fairseq/examples/fast_noisy_channel/__init__.py +fairseq/examples/fast_noisy_channel/noisy_channel_beam_search.py +fairseq/examples/fast_noisy_channel/noisy_channel_sequence_generator.py +fairseq/examples/fast_noisy_channel/noisy_channel_translation.py +fairseq/examples/flores101/README.md +fairseq/examples/flores101/flores_logo.png +fairseq/examples/fully_sharded_data_parallel/README.md +fairseq/examples/gottbert/README.md +fairseq/examples/hubert/README.md +fairseq/examples/hubert/measure_teacher_quality.py +fairseq/examples/hubert/update_ckpt.py +fairseq/examples/hubert/config/decode/infer_fsqlm.yaml +fairseq/examples/hubert/config/decode/infer_kenlm.yaml +fairseq/examples/hubert/config/decode/infer_viterbi.yaml +fairseq/examples/hubert/config/decode/ax_sweep/ngram.yaml +fairseq/examples/hubert/config/decode/ax_sweep/transformer.yaml +fairseq/examples/hubert/config/decode/run/submitit_slurm.yaml +fairseq/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml +fairseq/examples/hubert/config/finetune/base_10h.yaml +fairseq/examples/hubert/config/finetune/ckpt/it1.yaml +fairseq/examples/hubert/config/finetune/lm/ls_4gram.yaml 
+fairseq/examples/hubert/config/finetune/run/submitit_reg.yaml +fairseq/examples/hubert/config/pretrain/hubert_base_librispeech.yaml +fairseq/examples/hubert/config/pretrain/hubert_large_librivox.yaml +fairseq/examples/hubert/config/pretrain/hubert_xlarge_librivox.yaml +fairseq/examples/hubert/config/pretrain/data/iter1.yaml +fairseq/examples/hubert/config/pretrain/data/iter2.yaml +fairseq/examples/hubert/config/pretrain/run/submitit_reg.yaml +fairseq/examples/hubert/simple_kmeans/README.md +fairseq/examples/hubert/simple_kmeans/dump_hubert_feature.py +fairseq/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py +fairseq/examples/hubert/simple_kmeans/dump_km_label.py +fairseq/examples/hubert/simple_kmeans/dump_mfcc_feature.py +fairseq/examples/hubert/simple_kmeans/dump_w2v2_feature.py +fairseq/examples/hubert/simple_kmeans/feature_utils.py +fairseq/examples/hubert/simple_kmeans/learn_kmeans.py +fairseq/examples/hubert/tests/6313-76958-0021.flac +fairseq/examples/hubert/tests/sample.base.L9.km500.km +fairseq/examples/hubert/tests/sample.base.L9.len +fairseq/examples/hubert/tests/sample.base.L9.npy +fairseq/examples/hubert/tests/sample.large.L20.len +fairseq/examples/hubert/tests/sample.large.L20.npy +fairseq/examples/hubert/tests/sample.large.hypo.word +fairseq/examples/hubert/tests/sample.xlarge.L30.len +fairseq/examples/hubert/tests/sample.xlarge.L30.npy +fairseq/examples/hubert/tests/sample.xlarge.hypo.word +fairseq/examples/hubert/tests/test_feature_and_unit.sh +fairseq/examples/hubert/tests/test_finetuned_asr.sh +fairseq/examples/joint_alignment_translation/README.md +fairseq/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh +fairseq/examples/language_model/README.adaptive_inputs.md +fairseq/examples/language_model/README.conv.md +fairseq/examples/language_model/README.md +fairseq/examples/language_model/prepare-wikitext-103.sh +fairseq/examples/laser/README.md +fairseq/examples/laser/laser_src/__init__.py +fairseq/examples/laser/laser_src/laser_lstm.py +fairseq/examples/laser/laser_src/laser_task.py +fairseq/examples/laser/laser_src/laser_transformer.py +fairseq/examples/laser/laser_src/multitask_data_utils.py +fairseq/examples/latent_depth/README.md +fairseq/examples/latent_depth/latent_depth_src/__init__.py +fairseq/examples/latent_depth/latent_depth_src/multilingual_translation_latent_depth.py +fairseq/examples/latent_depth/latent_depth_src/loss/__init__.py +fairseq/examples/latent_depth/latent_depth_src/loss/latent_depth.py +fairseq/examples/latent_depth/latent_depth_src/models/__init__.py +fairseq/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py +fairseq/examples/latent_depth/latent_depth_src/models/latent_transformer.py +fairseq/examples/latent_depth/latent_depth_src/modules/__init__.py +fairseq/examples/latent_depth/latent_depth_src/modules/latent_layers.py +fairseq/examples/layerdrop/README.md +fairseq/examples/linformer/README.md +fairseq/examples/linformer/linformer_src/__init__.py +fairseq/examples/linformer/linformer_src/models/__init__.py +fairseq/examples/linformer/linformer_src/models/linformer_roberta.py +fairseq/examples/linformer/linformer_src/modules/__init__.py +fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py +fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py +fairseq/examples/linformer/linformer_src/modules/multihead_linear_attention.py +fairseq/examples/m2m_100/README.md +fairseq/examples/m2m_100/install_dependecies.sh 
+fairseq/examples/m2m_100/tok.sh +fairseq/examples/m2m_100/process_data/clean_histogram.py +fairseq/examples/m2m_100/process_data/dedup_data.py +fairseq/examples/m2m_100/process_data/remove_too_much_punc.py +fairseq/examples/m2m_100/tokenizers/README.md +fairseq/examples/m2m_100/tokenizers/seg_ja.sh +fairseq/examples/m2m_100/tokenizers/seg_ko.sh +fairseq/examples/m2m_100/tokenizers/tokenize_indic.py +fairseq/examples/m2m_100/tokenizers/tokenize_thai.py +fairseq/examples/m2m_100/tokenizers/tokenize_zh.py +fairseq/examples/m2m_100/tokenizers/tokenizer_ar.sh +fairseq/examples/m2m_100/tokenizers/thirdparty/.gitignore +fairseq/examples/mbart/README.md +fairseq/examples/megatron_11b/README.md +fairseq/examples/megatron_11b/detok.py +fairseq/examples/mms/MODEL_CARD.md +fairseq/examples/mms/README.md +fairseq/examples/mms/asr/config/infer_common.yaml +fairseq/examples/mms/asr/infer/example_infer_adapter.sh +fairseq/examples/mms/asr/infer/mms_infer.py +fairseq/examples/mms/asr/tutorial/MMS_ASR_Inference_Colab.ipynb +fairseq/examples/mms/data_prep/README.md +fairseq/examples/mms/data_prep/align_and_segment.py +fairseq/examples/mms/data_prep/align_utils.py +fairseq/examples/mms/data_prep/norm_config.py +fairseq/examples/mms/data_prep/punctuations.lst +fairseq/examples/mms/data_prep/text_normalization.py +fairseq/examples/mms/lid/infer.py +fairseq/examples/mms/lid/tutorial/MMS_LID_Inference_Colab.ipynb +fairseq/examples/mms/lid_rerank/README.md +fairseq/examples/mms/lid_rerank/cer_langs.txt +fairseq/examples/mms/lid_rerank/requirements.txt +fairseq/examples/mms/lid_rerank/mala/infer.py +fairseq/examples/mms/lid_rerank/mms/make_parallel_single_runs.py +fairseq/examples/mms/lid_rerank/mms/merge_by_lang.py +fairseq/examples/mms/lid_rerank/mms/prep_wav_list.py +fairseq/examples/mms/lid_rerank/mms/run_single_lang.py +fairseq/examples/mms/lid_rerank/mms/split_by_lang.py +fairseq/examples/mms/lid_rerank/mms-zs/falign.py +fairseq/examples/mms/lid_rerank/mms-zs/lib.py +fairseq/examples/mms/lid_rerank/mms-zs/uromanize.py +fairseq/examples/mms/lid_rerank/nllb/infer.py +fairseq/examples/mms/lid_rerank/rerank/rerank.py +fairseq/examples/mms/lid_rerank/rerank/tune_coefficients.py +fairseq/examples/mms/lid_rerank/whisper/infer_asr.py +fairseq/examples/mms/lid_rerank/whisper/infer_lid.py +fairseq/examples/mms/lid_rerank/whisper/lid_mapping.txt +fairseq/examples/mms/misc/get_sample_size.py +fairseq/examples/mms/tts/infer.py +fairseq/examples/mms/tts/tutorial/MMS_TTS_Inference_Colab.ipynb +fairseq/examples/mms/zero_shot/README.md +fairseq/examples/moe_lm/README.md +fairseq/examples/moe_lm/data_card.md +fairseq/examples/moe_lm/model_card.md +fairseq/examples/mr_hubert/README.md +fairseq/examples/mr_hubert/decode.sh +fairseq/examples/mr_hubert/finetune.sh +fairseq/examples/mr_hubert/train.sh +fairseq/examples/mr_hubert/config/decode/infer.yaml +fairseq/examples/mr_hubert/config/decode/infer_lm.yaml +fairseq/examples/mr_hubert/config/decode/run/submitit_slurm.yaml +fairseq/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml +fairseq/examples/mr_hubert/config/finetune/base_100h.yaml +fairseq/examples/mr_hubert/config/finetune/base_100h_large.yaml +fairseq/examples/mr_hubert/config/finetune/base_10h.yaml +fairseq/examples/mr_hubert/config/finetune/base_10h_large.yaml +fairseq/examples/mr_hubert/config/finetune/base_1h.yaml +fairseq/examples/mr_hubert/config/finetune/base_1h_large.yaml +fairseq/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml 
+fairseq/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml +fairseq/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml +fairseq/examples/mr_hubert/simple_kmeans/README.md +fairseq/examples/mr_hubert/simple_kmeans/dump_hubert_feature.py +fairseq/examples/mr_hubert/simple_kmeans/dump_hubert_feature_s2t.py +fairseq/examples/mr_hubert/simple_kmeans/dump_km_label.py +fairseq/examples/mr_hubert/simple_kmeans/dump_mfcc_feature.py +fairseq/examples/mr_hubert/simple_kmeans/dump_w2v2_feature.py +fairseq/examples/mr_hubert/simple_kmeans/feature_utils.py +fairseq/examples/mr_hubert/simple_kmeans/learn_kmeans.py +fairseq/examples/multilingual/ML50_langs.txt +fairseq/examples/multilingual/README.md +fairseq/examples/multilingual/finetune_multilingual_model.sh +fairseq/examples/multilingual/multilingual_fairseq_gen.sh +fairseq/examples/multilingual/train_multilingual_model.sh +fairseq/examples/multilingual/data_scripts/README.md +fairseq/examples/multilingual/data_scripts/binarize.py +fairseq/examples/multilingual/data_scripts/check_iswlt_test_data.py +fairseq/examples/multilingual/data_scripts/check_self_overlaps.py +fairseq/examples/multilingual/data_scripts/check_valid_test_overlaps.py +fairseq/examples/multilingual/data_scripts/dedup_all.py +fairseq/examples/multilingual/data_scripts/download_ML50_v1.sh +fairseq/examples/multilingual/data_scripts/download_af_xh.sh +fairseq/examples/multilingual/data_scripts/download_flores_data.sh +fairseq/examples/multilingual/data_scripts/download_iitb.sh +fairseq/examples/multilingual/data_scripts/download_iwslt_and_extract.sh +fairseq/examples/multilingual/data_scripts/download_lotus.sh +fairseq/examples/multilingual/data_scripts/download_ted_and_extract.py +fairseq/examples/multilingual/data_scripts/download_wat19_my.sh +fairseq/examples/multilingual/data_scripts/download_wmt19_and_before.py +fairseq/examples/multilingual/data_scripts/download_wmt20.sh +fairseq/examples/multilingual/data_scripts/preprocess_ML50_v1.sh +fairseq/examples/multilingual/data_scripts/remove_valid_test_in_train.py +fairseq/examples/multilingual/data_scripts/requirement.txt +fairseq/examples/multilingual/data_scripts/utils/dedup.py +fairseq/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py +fairseq/examples/multilingual/data_scripts/utils/strip_sgm.sh +fairseq/examples/noisychannel/README.md +fairseq/examples/noisychannel/__init__.py +fairseq/examples/noisychannel/rerank.py +fairseq/examples/noisychannel/rerank_generate.py +fairseq/examples/noisychannel/rerank_options.py +fairseq/examples/noisychannel/rerank_score_bw.py +fairseq/examples/noisychannel/rerank_score_lm.py +fairseq/examples/noisychannel/rerank_tune.py +fairseq/examples/noisychannel/rerank_utils.py +fairseq/examples/nonautoregressive_translation/README.md +fairseq/examples/nonautoregressive_translation/scripts.md +fairseq/examples/normformer/README.md +fairseq/examples/normformer/train_lm.sh +fairseq/examples/operators/alignment_train_cpu.cpp +fairseq/examples/operators/alignment_train_cuda.cpp +fairseq/examples/operators/alignment_train_cuda.h +fairseq/examples/operators/alignment_train_kernel.cu +fairseq/examples/operators/utils.h +fairseq/examples/paraphraser/README.md +fairseq/examples/paraphraser/paraphrase.py +fairseq/examples/pay_less_attention_paper/README.md +fairseq/examples/pointer_generator/README.md +fairseq/examples/pointer_generator/README.xsum.md +fairseq/examples/pointer_generator/postprocess.py +fairseq/examples/pointer_generator/preprocess.py 
+fairseq/examples/pointer_generator/pointer_generator_src/__init__.py +fairseq/examples/pointer_generator/pointer_generator_src/transformer_pg.py +fairseq/examples/quant_noise/README.md +fairseq/examples/quant_noise/transformer_quantization_config.yaml +fairseq/examples/roberta/README.custom_classification.md +fairseq/examples/roberta/README.glue.md +fairseq/examples/roberta/README.md +fairseq/examples/roberta/README.pretraining.md +fairseq/examples/roberta/README.race.md +fairseq/examples/roberta/multiprocessing_bpe_encoder.py +fairseq/examples/roberta/preprocess_GLUE_tasks.sh +fairseq/examples/roberta/preprocess_RACE.py +fairseq/examples/roberta/preprocess_RACE.sh +fairseq/examples/roberta/commonsense_qa/README.md +fairseq/examples/roberta/commonsense_qa/__init__.py +fairseq/examples/roberta/commonsense_qa/commonsense_qa_task.py +fairseq/examples/roberta/commonsense_qa/download_cqa_data.sh +fairseq/examples/roberta/config/finetuning/cola.yaml +fairseq/examples/roberta/config/finetuning/mnli.yaml +fairseq/examples/roberta/config/finetuning/mrpc.yaml +fairseq/examples/roberta/config/finetuning/qnli.yaml +fairseq/examples/roberta/config/finetuning/qqp.yaml +fairseq/examples/roberta/config/finetuning/rte.yaml +fairseq/examples/roberta/config/finetuning/sst_2.yaml +fairseq/examples/roberta/config/finetuning/sts_b.yaml +fairseq/examples/roberta/config/finetuning/run_config/local.yaml +fairseq/examples/roberta/config/finetuning/run_config/slurm_1g.yaml +fairseq/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml +fairseq/examples/roberta/config/pretraining/base.yaml +fairseq/examples/roberta/config/pretraining/run_config/local.yaml +fairseq/examples/roberta/config/pretraining/run_config/slurm_2.yaml +fairseq/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml +fairseq/examples/roberta/config/pretraining/run_config/slurm_3.yaml +fairseq/examples/roberta/config/pretraining/run_config/slurm_4.yaml +fairseq/examples/roberta/fb_multilingual/README.multilingual.pretraining.md +fairseq/examples/roberta/wsc/README.md +fairseq/examples/roberta/wsc/__init__.py +fairseq/examples/roberta/wsc/wsc_criterion.py +fairseq/examples/roberta/wsc/wsc_task.py +fairseq/examples/roberta/wsc/wsc_utils.py +fairseq/examples/rxf/README.md +fairseq/examples/rxf/__init__.py +fairseq/examples/rxf/rxf_src/__init__.py +fairseq/examples/rxf/rxf_src/label_smoothed_cross_entropy_r3f.py +fairseq/examples/rxf/rxf_src/sentence_prediction_r3f.py +fairseq/examples/scaling_nmt/README.md +fairseq/examples/shuffled_word_order/README.finetuning.md +fairseq/examples/shuffled_word_order/README.md +fairseq/examples/simultaneous_translation/README.md +fairseq/examples/simultaneous_translation/__init__.py +fairseq/examples/simultaneous_translation/docs/ende-mma.md +fairseq/examples/simultaneous_translation/docs/enja-waitk.md +fairseq/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py +fairseq/examples/simultaneous_translation/models/__init__.py +fairseq/examples/simultaneous_translation/models/convtransformer_simul_trans.py +fairseq/examples/simultaneous_translation/models/transformer_monotonic_attention.py +fairseq/examples/simultaneous_translation/modules/__init__.py +fairseq/examples/simultaneous_translation/modules/fixed_pre_decision.py +fairseq/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +fairseq/examples/simultaneous_translation/modules/monotonic_transformer_layer.py +fairseq/examples/simultaneous_translation/tests/test_alignment_train.py 
+fairseq/examples/simultaneous_translation/tests/test_text_models.py +fairseq/examples/simultaneous_translation/utils/__init__.py +fairseq/examples/simultaneous_translation/utils/functions.py +fairseq/examples/simultaneous_translation/utils/monotonic_attention.py +fairseq/examples/simultaneous_translation/utils/p_choose_strategy.py +fairseq/examples/speech_recognition/README.md +fairseq/examples/speech_recognition/__init__.py +fairseq/examples/speech_recognition/infer.py +fairseq/examples/speech_recognition/w2l_decoder.py +fairseq/examples/speech_recognition/criterions/ASG_loss.py +fairseq/examples/speech_recognition/criterions/__init__.py +fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py +fairseq/examples/speech_recognition/data/__init__.py +fairseq/examples/speech_recognition/data/asr_dataset.py +fairseq/examples/speech_recognition/data/collaters.py +fairseq/examples/speech_recognition/data/data_utils.py +fairseq/examples/speech_recognition/data/replabels.py +fairseq/examples/speech_recognition/datasets/asr_prep_json.py +fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh +fairseq/examples/speech_recognition/kaldi/__init__.py +fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc +fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py +fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py +fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml +fairseq/examples/speech_recognition/models/__init__.py +fairseq/examples/speech_recognition/models/vggtransformer.py +fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py +fairseq/examples/speech_recognition/new/README.md +fairseq/examples/speech_recognition/new/__init__.py +fairseq/examples/speech_recognition/new/infer.py +fairseq/examples/speech_recognition/new/conf/infer.yaml +fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml +fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml +fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml +fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml +fairseq/examples/speech_recognition/new/decoders/__init__.py +fairseq/examples/speech_recognition/new/decoders/base_decoder.py +fairseq/examples/speech_recognition/new/decoders/decoder.py +fairseq/examples/speech_recognition/new/decoders/decoder_config.py +fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py +fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py +fairseq/examples/speech_recognition/tasks/__init__.py +fairseq/examples/speech_recognition/tasks/speech_recognition.py +fairseq/examples/speech_recognition/utils/wer_utils.py +fairseq/examples/speech_synthesis/README.md +fairseq/examples/speech_synthesis/__init__.py +fairseq/examples/speech_synthesis/data_utils.py +fairseq/examples/speech_synthesis/generate_waveform.py +fairseq/examples/speech_synthesis/utils.py +fairseq/examples/speech_synthesis/docs/common_voice_example.md +fairseq/examples/speech_synthesis/docs/ljspeech_example.md +fairseq/examples/speech_synthesis/docs/vctk_example.md +fairseq/examples/speech_synthesis/evaluation/__init__.py +fairseq/examples/speech_synthesis/evaluation/eval_asr.py +fairseq/examples/speech_synthesis/evaluation/eval_f0.py +fairseq/examples/speech_synthesis/evaluation/eval_sp.py +fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py +fairseq/examples/speech_synthesis/preprocessing/__init__.py +fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py 
+fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py +fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py +fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py +fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py +fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py +fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py +fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py +fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py +fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py +fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py +fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py +fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py +fairseq/examples/speech_text_joint_to_text/README.md +fairseq/examples/speech_text_joint_to_text/__init__.py +fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list +fairseq/examples/speech_text_joint_to_text/criterions/__init__.py +fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py +fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py +fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py +fairseq/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py +fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md +fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md +fairseq/examples/speech_text_joint_to_text/docs/pre-training.md +fairseq/examples/speech_text_joint_to_text/models/__init__.py +fairseq/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py +fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py +fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py +fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py +fairseq/examples/speech_text_joint_to_text/scripts/convert_model.py +fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py +fairseq/examples/speech_text_joint_to_text/tasks/__init__.py +fairseq/examples/speech_text_joint_to_text/tasks/pair_denoising.py +fairseq/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py +fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py +fairseq/examples/speech_to_speech/README.md +fairseq/examples/speech_to_speech/__init__.py +fairseq/examples/speech_to_speech/generate_waveform_from_code.py +fairseq/examples/speech_to_speech/asr_bleu/README.md +fairseq/examples/speech_to_speech/asr_bleu/__init__.py +fairseq/examples/speech_to_speech/asr_bleu/asr_model_cfgs.json +fairseq/examples/speech_to_speech/asr_bleu/compute_asr_bleu.py +fairseq/examples/speech_to_speech/asr_bleu/requirements.txt +fairseq/examples/speech_to_speech/asr_bleu/utils.py +fairseq/examples/speech_to_speech/benchmarking/README.md +fairseq/examples/speech_to_speech/benchmarking/core.py +fairseq/examples/speech_to_speech/benchmarking/data_utils.py +fairseq/examples/speech_to_speech/benchmarking/get_metrics.py +fairseq/examples/speech_to_speech/benchmarking/configs/2StageS2ST.yaml +fairseq/examples/speech_to_speech/benchmarking/configs/3StageS2ST.yaml +fairseq/examples/speech_to_speech/benchmarking/configs/DirectS2U.yaml +fairseq/examples/speech_to_speech/benchmarking/configs/S2T.yaml +fairseq/examples/speech_to_speech/docs/data_augmentation.md 
+fairseq/examples/speech_to_speech/docs/direct_s2st_discrete_units.md +fairseq/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md +fairseq/examples/speech_to_speech/docs/textless_s2st_real_data.md +fairseq/examples/speech_to_speech/preprocessing/__init__.py +fairseq/examples/speech_to_speech/preprocessing/data_utils.py +fairseq/examples/speech_to_speech/preprocessing/prep_s2spect_data.py +fairseq/examples/speech_to_speech/preprocessing/prep_s2ut_data.py +fairseq/examples/speech_to_speech/preprocessing/prep_sn_data.py +fairseq/examples/speech_to_speech/preprocessing/prep_sn_output_data.py +fairseq/examples/speech_to_speech/unity/__init__.py +fairseq/examples/speech_to_speech/unity/sequence_generator.py +fairseq/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py +fairseq/examples/speech_to_text/README.md +fairseq/examples/speech_to_text/data_utils.py +fairseq/examples/speech_to_text/prep_covost_data.py +fairseq/examples/speech_to_text/prep_librispeech_data.py +fairseq/examples/speech_to_text/prep_mtedx_data.py +fairseq/examples/speech_to_text/prep_mustc_data.py +fairseq/examples/speech_to_text/seg_mustc_data.py +fairseq/examples/speech_to_text/docs/covost_example.md +fairseq/examples/speech_to_text/docs/librispeech_example.md +fairseq/examples/speech_to_text/docs/mtedx_example.md +fairseq/examples/speech_to_text/docs/mustc_example.md +fairseq/examples/speech_to_text/docs/simulst_mustc_example.md +fairseq/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +fairseq/examples/stories/README.md +fairseq/examples/textless_nlp/dgslm/README.md +fairseq/examples/textless_nlp/dgslm/create_code_file.py +fairseq/examples/textless_nlp/dgslm/dgslm_utils.py +fairseq/examples/textless_nlp/dgslm/sample_speech_dlm.py +fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md +fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md +fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/generate_stereo_waveform.py +fairseq/examples/textless_nlp/gslm/README.md +fairseq/examples/textless_nlp/gslm/metrics/README.md +fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/README.md +fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py +fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md +fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py +fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py +fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py +fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py +fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py +fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt +fairseq/examples/textless_nlp/gslm/speech2unit/README.md +fairseq/examples/textless_nlp/gslm/speech2unit/__init__.py +fairseq/examples/textless_nlp/gslm/speech2unit/clustering/__init__.py +fairseq/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py +fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py +fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py +fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py +fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py +fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py +fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py +fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py 
+fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py +fairseq/examples/textless_nlp/gslm/tools/README.md +fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py +fairseq/examples/textless_nlp/gslm/ulm/README.md +fairseq/examples/textless_nlp/gslm/ulm/sample.py +fairseq/examples/textless_nlp/gslm/unit2speech/README.md +fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py +fairseq/examples/textless_nlp/gslm/unit2speech/glow.py +fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py +fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py +fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py +fairseq/examples/textless_nlp/gslm/unit2speech/utils.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py +fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py +fairseq/examples/textless_nlp/pgslm/README.md +fairseq/examples/textless_nlp/pgslm/data_utils.py +fairseq/examples/textless_nlp/pgslm/generate_waveform.py +fairseq/examples/textless_nlp/pgslm/inference_dataset.py +fairseq/examples/textless_nlp/pgslm/naive_decoder.py +fairseq/examples/textless_nlp/pgslm/prepare_dataset.py +fairseq/examples/textless_nlp/pgslm/preprocess_f0.py +fairseq/examples/textless_nlp/pgslm/quantize_f0.py +fairseq/examples/textless_nlp/pgslm/truncated_laplace.py +fairseq/examples/textless_nlp/pgslm/eval/__init__.py +fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py +fairseq/examples/textless_nlp/pgslm/sample/__init__.py +fairseq/examples/textless_nlp/pgslm/sample/sample.py +fairseq/examples/textless_nlp/pgslm/scripts/join_units_manifest.py +fairseq/examples/textless_nlp/pgslm/scripts/prepare_data.sh +fairseq/examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh +fairseq/examples/textless_nlp/speech-resynth/README.md +fairseq/examples/textless_nlp/speech-resynth/img/fig.png +fairseq/examples/translation/README.md +fairseq/examples/translation/prepare-iwslt14.sh +fairseq/examples/translation/prepare-iwslt17-multilingual.sh +fairseq/examples/translation/prepare-wmt14en2de.sh +fairseq/examples/translation/prepare-wmt14en2fr.sh +fairseq/examples/translation_moe/README.md +fairseq/examples/translation_moe/score.py +fairseq/examples/translation_moe/translation_moe_src/__init__.py +fairseq/examples/translation_moe/translation_moe_src/logsumexp_moe.py +fairseq/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py +fairseq/examples/translation_moe/translation_moe_src/translation_moe.py +fairseq/examples/truncated_bptt/README.md +fairseq/examples/truncated_bptt/__init__.py +fairseq/examples/truncated_bptt/transformer_xl_model.py +fairseq/examples/truncated_bptt/truncated_bptt_lm_task.py +fairseq/examples/unsupervised_quality_estimation/README.md 
+fairseq/examples/unsupervised_quality_estimation/aggregate_scores.py +fairseq/examples/unsupervised_quality_estimation/meteor.py +fairseq/examples/unsupervised_quality_estimation/repeat_lines.py +fairseq/examples/wav2vec/README.md +fairseq/examples/wav2vec/__init__.py +fairseq/examples/wav2vec/libri_labels.py +fairseq/examples/wav2vec/vq-wav2vec_featurize.py +fairseq/examples/wav2vec/wav2vec_featurize.py +fairseq/examples/wav2vec/wav2vec_manifest.py +fairseq/examples/wav2vec/config/finetuning/base_100h.yaml +fairseq/examples/wav2vec/config/finetuning/base_10h.yaml +fairseq/examples/wav2vec/config/finetuning/base_10m.yaml +fairseq/examples/wav2vec/config/finetuning/base_1h.yaml +fairseq/examples/wav2vec/config/finetuning/base_960h.yaml +fairseq/examples/wav2vec/config/finetuning/vox_100h.yaml +fairseq/examples/wav2vec/config/finetuning/vox_100h_2.yaml +fairseq/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml +fairseq/examples/wav2vec/config/finetuning/vox_100h_3.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10h.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10h_2.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10h_aws.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10m.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10m_2.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml +fairseq/examples/wav2vec/config/finetuning/vox_10m_3.yaml +fairseq/examples/wav2vec/config/finetuning/vox_1h.yaml +fairseq/examples/wav2vec/config/finetuning/vox_1h_2.yaml +fairseq/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml +fairseq/examples/wav2vec/config/finetuning/vox_1h_3.yaml +fairseq/examples/wav2vec/config/finetuning/vox_1h_4.yaml +fairseq/examples/wav2vec/config/finetuning/vox_1h_aws.yaml +fairseq/examples/wav2vec/config/finetuning/vox_960h.yaml +fairseq/examples/wav2vec/config/finetuning/vox_960h_2.yaml +fairseq/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml +fairseq/examples/wav2vec/config/finetuning/vox_960h_3.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml +fairseq/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml +fairseq/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml +fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_base_librispeech.yaml +fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_large_librivox.yaml +fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml +fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml +fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml +fairseq/examples/wav2vec/scripts/binarize_manifest.sh +fairseq/examples/wav2vec/unsupervised/README.md +fairseq/examples/wav2vec/unsupervised/__init__.py 
+fairseq/examples/wav2vec/unsupervised/w2vu_generate.py +fairseq/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml +fairseq/examples/wav2vec/unsupervised/config/gan/w2vu.yaml +fairseq/examples/wav2vec/unsupervised/config/gan/w2vu2.yaml +fairseq/examples/wav2vec/unsupervised/config/generate/viterbi.yaml +fairseq/examples/wav2vec/unsupervised/config/timit_matched/test.uid +fairseq/examples/wav2vec/unsupervised/config/timit_matched/train.uid +fairseq/examples/wav2vec/unsupervised/config/timit_matched/train_text.uid +fairseq/examples/wav2vec/unsupervised/config/timit_matched/valid.uid +fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/test.uid +fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train.uid +fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid +fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid +fairseq/examples/wav2vec/unsupervised/data/__init__.py +fairseq/examples/wav2vec/unsupervised/data/extracted_features_dataset.py +fairseq/examples/wav2vec/unsupervised/data/random_input_dataset.py +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/README.md +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_phone.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step1.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step2.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/path.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/train.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/decode.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh +fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh +fairseq/examples/wav2vec/unsupervised/models/__init__.py +fairseq/examples/wav2vec/unsupervised/models/wav2vec_u.py +fairseq/examples/wav2vec/unsupervised/scripts/apply_pca.py +fairseq/examples/wav2vec/unsupervised/scripts/copy_labels.py +fairseq/examples/wav2vec/unsupervised/scripts/filter_lexicon.py +fairseq/examples/wav2vec/unsupervised/scripts/filter_tsv.py +fairseq/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py +fairseq/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py +fairseq/examples/wav2vec/unsupervised/scripts/mean_pool.py +fairseq/examples/wav2vec/unsupervised/scripts/merge_clusters.py +fairseq/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py 
+fairseq/examples/wav2vec/unsupervised/scripts/normalize_text.py +fairseq/examples/wav2vec/unsupervised/scripts/pca.py +fairseq/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py +fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio.sh +fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio_v2.sh +fairseq/examples/wav2vec/unsupervised/scripts/prepare_text.sh +fairseq/examples/wav2vec/unsupervised/scripts/prepare_timit.sh +fairseq/examples/wav2vec/unsupervised/scripts/remove_silence.py +fairseq/examples/wav2vec/unsupervised/scripts/vads.py +fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py +fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py +fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py +fairseq/examples/wav2vec/unsupervised/scripts/wer.py +fairseq/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py +fairseq/examples/wav2vec/unsupervised/tasks/__init__.py +fairseq/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py +fairseq/examples/wav2vec/xlsr/README.md +fairseq/examples/wav2vec/xlsr/config/finetune.yaml +fairseq/examples/wav2vec/xlsr/scripts/eval_speaker_clf_task.py +fairseq/examples/wav2vec/xlsr/scripts/gen_audio_embedding.py +fairseq/examples/wmt19/README.md +fairseq/examples/wmt20/README.md +fairseq/examples/wmt21/README.md +fairseq/examples/wmt21/eval.sh +fairseq/examples/wmt21/scripts/normalize-punctuation.perl +fairseq/examples/wmt21/scripts/replace-unicode-punctuation.perl +fairseq/examples/womens_bios/README.md +fairseq/examples/womens_bios/query_occupations_from_wikidata.py +fairseq/examples/xformers/README.md +fairseq/examples/xglm/README.md +fairseq/examples/xglm/XStoryCloze.md +fairseq/examples/xglm/model_card.md +fairseq/examples/xlmr/README.md +fairseq/examples/xmod/README.md +fairseq/examples/xmod/preprocess_nli.py +fairseq/logging/__init__.py +fairseq/logging/meters.py +fairseq/logging/metrics.py +fairseq/logging/progress_bar.py +fairseq/model_parallel/__init__.py +fairseq/model_parallel/megatron_trainer.py +fairseq/model_parallel/criterions/__init__.py +fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py +fairseq/model_parallel/models/__init__.py +fairseq/model_parallel/models/transformer.py +fairseq/model_parallel/models/transformer_lm.py +fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py +fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py +fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +fairseq/model_parallel/models/roberta/__init__.py +fairseq/model_parallel/models/roberta/model.py +fairseq/model_parallel/modules/__init__.py +fairseq/model_parallel/modules/multihead_attention.py +fairseq/model_parallel/modules/transformer_layer.py +fairseq/models/__init__.py +fairseq/models/composite_encoder.py +fairseq/models/distributed_fairseq_model.py +fairseq/models/fairseq_decoder.py +fairseq/models/fairseq_encoder.py +fairseq/models/fairseq_incremental_decoder.py +fairseq/models/fairseq_model.py +fairseq/models/fconv.py +fairseq/models/fconv_lm.py +fairseq/models/fconv_self_att.py +fairseq/models/lightconv.py +fairseq/models/lightconv_lm.py +fairseq/models/lstm.py +fairseq/models/lstm_lm.py +fairseq/models/masked_lm.py +fairseq/models/model_utils.py +fairseq/models/multilingual_transformer.py +fairseq/models/transformer_align.py +fairseq/models/transformer_from_pretrained_xlm.py +fairseq/models/transformer_lm.py +fairseq/models/transformer_ulm.py +fairseq/models/bart/__init__.py 
+fairseq/models/bart/hub_interface.py +fairseq/models/bart/model.py +fairseq/models/ema/__init__.py +fairseq/models/ema/ema.py +fairseq/models/hubert/__init__.py +fairseq/models/hubert/hubert.py +fairseq/models/hubert/hubert_asr.py +fairseq/models/huggingface/__init__.py +fairseq/models/huggingface/hf_gpt2.py +fairseq/models/multires_hubert/__init__.py +fairseq/models/multires_hubert/multires_hubert.py +fairseq/models/multires_hubert/multires_hubert_asr.py +fairseq/models/nat/__init__.py +fairseq/models/nat/cmlm_transformer.py +fairseq/models/nat/fairseq_nat_model.py +fairseq/models/nat/insertion_transformer.py +fairseq/models/nat/iterative_nonautoregressive_transformer.py +fairseq/models/nat/levenshtein_transformer.py +fairseq/models/nat/levenshtein_utils.py +fairseq/models/nat/nat_crf_transformer.py +fairseq/models/nat/nonautoregressive_ensembles.py +fairseq/models/nat/nonautoregressive_transformer.py +fairseq/models/roberta/__init__.py +fairseq/models/roberta/alignment_utils.py +fairseq/models/roberta/enc_dec.py +fairseq/models/roberta/hub_interface.py +fairseq/models/roberta/model.py +fairseq/models/roberta/model_camembert.py +fairseq/models/roberta/model_gottbert.py +fairseq/models/roberta/model_xlmr.py +fairseq/models/speech_dlm/__init__.py +fairseq/models/speech_dlm/hub_interface.py +fairseq/models/speech_dlm/speech_dlm.py +fairseq/models/speech_dlm/modules/__init__.py +fairseq/models/speech_dlm/modules/speech_dlm_decoder.py +fairseq/models/speech_dlm/modules/speech_dlm_decoder_layer.py +fairseq/models/speech_dlm/sequence_generator/__init__.py +fairseq/models/speech_dlm/sequence_generator/multichannel_search.py +fairseq/models/speech_dlm/sequence_generator/multichannel_sequence_generator.py +fairseq/models/speech_to_speech/__init__.py +fairseq/models/speech_to_speech/s2s_conformer.py +fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py +fairseq/models/speech_to_speech/s2s_conformer_unity.py +fairseq/models/speech_to_speech/s2s_transformer.py +fairseq/models/speech_to_speech/modules/__init__.py +fairseq/models/speech_to_speech/modules/ctc_decoder.py +fairseq/models/speech_to_speech/modules/stacked_embedding.py +fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py +fairseq/models/speech_to_speech/modules/transformer_encoder.py +fairseq/models/speech_to_text/__init__.py +fairseq/models/speech_to_text/berard.py +fairseq/models/speech_to_text/convtransformer.py +fairseq/models/speech_to_text/hub_interface.py +fairseq/models/speech_to_text/multi_modality_model.py +fairseq/models/speech_to_text/s2t_conformer.py +fairseq/models/speech_to_text/s2t_transformer.py +fairseq/models/speech_to_text/s2t_wav_transformer.py +fairseq/models/speech_to_text/utils.py +fairseq/models/speech_to_text/xm_transformer.py +fairseq/models/speech_to_text/xm_transformer_unity.py +fairseq/models/speech_to_text/modules/__init__.py +fairseq/models/speech_to_text/modules/augmented_memory_attention.py +fairseq/models/speech_to_text/modules/convolution.py +fairseq/models/speech_to_text/modules/emformer.py +fairseq/models/text_to_speech/__init__.py +fairseq/models/text_to_speech/codehifigan.py +fairseq/models/text_to_speech/fastspeech2.py +fairseq/models/text_to_speech/hifigan.py +fairseq/models/text_to_speech/hub_interface.py +fairseq/models/text_to_speech/tacotron2.py +fairseq/models/text_to_speech/tts_transformer.py +fairseq/models/text_to_speech/vocoder.py +fairseq/models/transformer/__init__.py +fairseq/models/transformer/transformer_base.py +fairseq/models/transformer/transformer_config.py 
+fairseq/models/transformer/transformer_decoder.py +fairseq/models/transformer/transformer_decoder_aug.py +fairseq/models/transformer/transformer_encoder.py +fairseq/models/transformer/transformer_legacy.py +fairseq/models/wav2vec/__init__.py +fairseq/models/wav2vec/utils.py +fairseq/models/wav2vec/wav2vec.py +fairseq/models/wav2vec/wav2vec2.py +fairseq/models/wav2vec/wav2vec2_asr.py +fairseq/models/wav2vec/wav2vec2_classification.py +fairseq/models/wav2vec/wav2vec2_laser.py +fairseq/models/xmod/__init__.py +fairseq/models/xmod/hub_interface.py +fairseq/models/xmod/model.py +fairseq/models/xmod/transformer_layer_xmod.py +fairseq/modules/__init__.py +fairseq/modules/adaptive_input.py +fairseq/modules/adaptive_softmax.py +fairseq/modules/base_layer.py +fairseq/modules/beamable_mm.py +fairseq/modules/character_token_embedder.py +fairseq/modules/checkpoint_activations.py +fairseq/modules/conformer_layer.py +fairseq/modules/conv_tbc.py +fairseq/modules/cross_entropy.py +fairseq/modules/downsampled_multihead_attention.py +fairseq/modules/dynamic_convolution.py +fairseq/modules/dynamic_crf_layer.py +fairseq/modules/ema_module.py +fairseq/modules/espnet_multihead_attention.py +fairseq/modules/fairseq_dropout.py +fairseq/modules/fp32_batch_norm.py +fairseq/modules/fp32_group_norm.py +fairseq/modules/fp32_instance_norm.py +fairseq/modules/gelu.py +fairseq/modules/grad_multiply.py +fairseq/modules/gumbel_vector_quantizer.py +fairseq/modules/kmeans_attention.py +fairseq/modules/kmeans_vector_quantizer.py +fairseq/modules/layer_drop.py +fairseq/modules/layer_norm.py +fairseq/modules/learned_positional_embedding.py +fairseq/modules/lightweight_convolution.py +fairseq/modules/linearized_convolution.py +fairseq/modules/location_attention.py +fairseq/modules/lstm_cell_with_zoneout.py +fairseq/modules/multihead_attention.py +fairseq/modules/positional_embedding.py +fairseq/modules/positional_encoding.py +fairseq/modules/quant_noise.py +fairseq/modules/rotary_positional_embedding.py +fairseq/modules/same_pad.py +fairseq/modules/scalar_bias.py +fairseq/modules/sinusoidal_positional_embedding.py +fairseq/modules/sparse_multihead_attention.py +fairseq/modules/sparse_transformer_sentence_encoder.py +fairseq/modules/sparse_transformer_sentence_encoder_layer.py +fairseq/modules/transformer_layer.py +fairseq/modules/transformer_layer_aug.py +fairseq/modules/transformer_sentence_encoder.py +fairseq/modules/transformer_sentence_encoder_layer.py +fairseq/modules/transpose_last.py +fairseq/modules/unfold.py +fairseq/modules/vggblock.py +fairseq/modules/dynamicconv_layer/__init__.py +fairseq/modules/dynamicconv_layer/cuda_function_gen.py +fairseq/modules/dynamicconv_layer/dynamicconv_layer.py +fairseq/modules/dynamicconv_layer/setup.py +fairseq/modules/lightconv_layer/__init__.py +fairseq/modules/lightconv_layer/cuda_function_gen.py +fairseq/modules/lightconv_layer/lightconv_layer.py +fairseq/modules/lightconv_layer/setup.py +fairseq/modules/quantization/__init__.py +fairseq/modules/quantization/quantization_options.py +fairseq/modules/quantization/pq/__init__.py +fairseq/modules/quantization/pq/em.py +fairseq/modules/quantization/pq/pq.py +fairseq/modules/quantization/pq/utils.py +fairseq/modules/quantization/pq/modules/__init__.py +fairseq/modules/quantization/pq/modules/qconv.py +fairseq/modules/quantization/pq/modules/qemb.py +fairseq/modules/quantization/pq/modules/qlinear.py +fairseq/modules/quantization/scalar/__init__.py +fairseq/modules/quantization/scalar/ops.py +fairseq/modules/quantization/scalar/utils.py 
+fairseq/modules/quantization/scalar/modules/__init__.py +fairseq/modules/quantization/scalar/modules/qact.py +fairseq/modules/quantization/scalar/modules/qconv.py +fairseq/modules/quantization/scalar/modules/qemb.py +fairseq/modules/quantization/scalar/modules/qlinear.py +fairseq/optim/__init__.py +fairseq/optim/adadelta.py +fairseq/optim/adafactor.py +fairseq/optim/adagrad.py +fairseq/optim/adam.py +fairseq/optim/adamax.py +fairseq/optim/amp_optimizer.py +fairseq/optim/bmuf.py +fairseq/optim/composite.py +fairseq/optim/cpu_adam.py +fairseq/optim/dynamic_loss_scaler.py +fairseq/optim/fairseq_optimizer.py +fairseq/optim/fp16_optimizer.py +fairseq/optim/fused_adam.py +fairseq/optim/fused_lamb.py +fairseq/optim/nag.py +fairseq/optim/sgd.py +fairseq/optim/shard.py +fairseq/optim/lr_scheduler/__init__.py +fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +fairseq/optim/lr_scheduler/fixed_schedule.py +fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +fairseq/optim/lr_scheduler/manual_lr_scheduler.py +fairseq/optim/lr_scheduler/pass_through.py +fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +fairseq/optim/lr_scheduler/step_lr_scheduler.py +fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +fairseq/scoring/__init__.py +fairseq/scoring/bertscore.py +fairseq/scoring/bleu.py +fairseq/scoring/chrf.py +fairseq/scoring/meteor.py +fairseq/scoring/tokenizer.py +fairseq/scoring/wer.py +fairseq/tasks/__init__.py +fairseq/tasks/audio_classification.py +fairseq/tasks/audio_finetuning.py +fairseq/tasks/audio_pretraining.py +fairseq/tasks/cross_lingual_lm.py +fairseq/tasks/denoising.py +fairseq/tasks/fairseq_task.py +fairseq/tasks/frm_text_to_speech.py +fairseq/tasks/hubert_pretraining.py +fairseq/tasks/language_modeling.py +fairseq/tasks/legacy_masked_lm.py +fairseq/tasks/masked_lm.py +fairseq/tasks/multilingual_denoising.py +fairseq/tasks/multilingual_language_modeling.py +fairseq/tasks/multilingual_masked_lm.py +fairseq/tasks/multilingual_translation.py +fairseq/tasks/multires_hubert_pretraining.py +fairseq/tasks/nlu_finetuning.py +fairseq/tasks/online_backtranslation.py +fairseq/tasks/semisupervised_translation.py +fairseq/tasks/sentence_prediction.py +fairseq/tasks/sentence_prediction_adapters.py +fairseq/tasks/sentence_ranking.py +fairseq/tasks/simultaneous_translation.py +fairseq/tasks/span_masked_lm.py +fairseq/tasks/speech_dlm_task.py +fairseq/tasks/speech_to_speech.py +fairseq/tasks/speech_to_text.py +fairseq/tasks/speech_ulm_task.py +fairseq/tasks/text_to_speech.py +fairseq/tasks/translation.py +fairseq/tasks/translation_from_pretrained_bart.py +fairseq/tasks/translation_from_pretrained_xlm.py +fairseq/tasks/translation_lev.py +fairseq/tasks/translation_multi_simple_epoch.py +fairseq_cli/__init__.py +fairseq_cli/eval_lm.py +fairseq_cli/generate.py +fairseq_cli/hydra_train.py +fairseq_cli/hydra_validate.py +fairseq_cli/interactive.py +fairseq_cli/preprocess.py +fairseq_cli/score.py +fairseq_cli/train.py +fairseq_cli/validate.py +tests/test_activation_checkpointing.py +tests/test_amp_optimizer.py +tests/test_average_checkpoints.py +tests/test_backtranslation_dataset.py +tests/test_binaries.py +tests/test_binarizer.py +tests/test_character_token_embedder.py +tests/test_checkpoint_utils.py +tests/test_checkpoint_utils_for_task_level_attributes.py +tests/test_concat_dataset.py +tests/test_constraints.py +tests/test_convtbc.py 
+tests/test_data_utils.py +tests/test_dataclass_utils.py +tests/test_dataset.py +tests/test_dictionary.py +tests/test_ema.py +tests/test_espnet_multihead_attention.py +tests/test_export.py +tests/test_file_chunker_utils.py +tests/test_file_io.py +tests/test_fp16_optimizer.py +tests/test_hf_hub.py +tests/test_huffman.py +tests/test_inference_dropout.py +tests/test_iopath.py +tests/test_iterators.py +tests/test_label_smoothing.py +tests/test_lm_context_window.py +tests/test_lstm_jitable.py +tests/test_memory_efficient_fp16.py +tests/test_metrics.py +tests/test_multi_corpus_dataset.py +tests/test_multi_corpus_sampled_dataset.py +tests/test_multihead_attention.py +tests/test_noising.py +tests/test_online_backtranslation.py +tests/test_plasma_utils.py +tests/test_positional_encoding.py +tests/test_reproducibility.py +tests/test_resampling_dataset.py +tests/test_roberta.py +tests/test_rotary_positional_embedding.py +tests/test_sequence_generator.py +tests/test_sequence_scorer.py +tests/test_sparse_multihead_attention.py +tests/test_token_block_dataset.py +tests/test_train.py +tests/test_transformer.py +tests/test_utils.py +tests/test_valid_subset_checks.py \ No newline at end of file diff --git a/fairseq/fairseq.egg-info/entry_points.txt b/fairseq/fairseq.egg-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..60badfca717c90c02cbbb9a196ade2ab1580b32c --- /dev/null +++ b/fairseq/fairseq.egg-info/entry_points.txt @@ -0,0 +1,9 @@ +[console_scripts] +fairseq-eval-lm = fairseq_cli.eval_lm:cli_main +fairseq-generate = fairseq_cli.generate:cli_main +fairseq-hydra-train = fairseq_cli.hydra_train:cli_main +fairseq-interactive = fairseq_cli.interactive:cli_main +fairseq-preprocess = fairseq_cli.preprocess:cli_main +fairseq-score = fairseq_cli.score:cli_main +fairseq-train = fairseq_cli.train:cli_main +fairseq-validate = fairseq_cli.validate:cli_main diff --git a/fairseq/fairseq.egg-info/requires.txt b/fairseq/fairseq.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..78d4110dbb94c8b612977c849e9f50360e9e3e32 --- /dev/null +++ b/fairseq/fairseq.egg-info/requires.txt @@ -0,0 +1,22 @@ +cffi +cython +hydra-core<1.1,>=1.0.7 +omegaconf<2.1 +numpy>=1.21.3 +regex +sacrebleu>=1.4.12 +torch>=1.13 +tqdm +bitarray +torchaudio>=0.8.0 +scikit-learn +packaging + +[dev] +flake8 +pytest +black==22.3.0 + +[docs] +sphinx +sphinx-argparse diff --git a/fairseq/fairseq.egg-info/top_level.txt b/fairseq/fairseq.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a2f684b681cd733e3ae9c9df6dad7b6fca32993 --- /dev/null +++ b/fairseq/fairseq.egg-info/top_level.txt @@ -0,0 +1,4 @@ +alignment_train_cpu_binding +alignment_train_cuda_binding +fairseq +fairseq_cli diff --git a/fairseq/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc b/fairseq/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86bf2950da3e8d61006332ee34ff8862b34c92a0 Binary files /dev/null and b/fairseq/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc differ diff --git a/fairseq/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc b/fairseq/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e18b4b07f04961ae97ce7a3ff0bd6fa1cf52af8c Binary files /dev/null and b/fairseq/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc differ diff --git 
a/fairseq/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc b/fairseq/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9356f319af1f9be7cf6b0f6488caee8b93328dc8 Binary files /dev/null and b/fairseq/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc differ diff --git a/fairseq/fairseq/__pycache__/pdb.cpython-310.pyc b/fairseq/fairseq/__pycache__/pdb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dd60756a208f27d88e939197aa828007aecc5f7 Binary files /dev/null and b/fairseq/fairseq/__pycache__/pdb.cpython-310.pyc differ diff --git a/fairseq/fairseq_cli/__init__.py b/fairseq/fairseq_cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/fairseq_cli/eval_lm.py b/fairseq/fairseq_cli/eval_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd1450a9e025f887e497acc6ae431f573d475e1 --- /dev/null +++ b/fairseq/fairseq_cli/eval_lm.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Evaluate the perplexity of a trained language model. +""" + +import logging +import math +import os +import sys +from argparse import Namespace +from typing import Iterable, List, Optional + +import torch +from omegaconf import DictConfig + +import fairseq +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter +from fairseq.sequence_scorer import SequenceScorer + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.eval_lm") + + +def eval_lm( + models: List[fairseq.models.FairseqModel], + source_dictionary: fairseq.data.Dictionary, + batch_iterator: Iterable, + post_process: Optional[str] = None, + output_word_probs: bool = False, + output_word_stats: bool = False, + target_dictionary: Optional[fairseq.data.Dictionary] = None, + softmax_batch: int = 0, + remove_bos_token: bool = False, + device: Optional[torch.device] = None, +): + """ + Args: + models (List[~fairseq.models.FairseqModel]): list of models to + evaluate. Models are essentially `nn.Module` instances, but + must be compatible with fairseq's `SequenceScorer`. + source_dictionary (~fairseq.data.Dictionary): dictionary for + applying any relevant post processing or outputting word + probs/stats. + batch_iterator (Iterable): yields batches of data + post_process (Optional[str]): post-process text by removing BPE, + letter segmentation, etc. Valid options can be found in + fairseq.data.utils.post_process, although not all options + are implemented here. 
+ output_word_probs (bool): output words and their + predicted log probabilities + output_word_stats (bool): output word statistics such + as word count and average probability + target_dictionary (Optional[~fairseq.data.Dictionary]): output + dictionary (defaults to *source_dictionary*) + softmax_batch (int): if B x T exceeds this value, batch the + softmax over the vocabulary into chunks of this many tokens, + in order to fit into GPU memory + remove_bos_token (bool): if True, confirm that the + first token is the beginning-of-sentence symbol (according + to the relevant dictionary) and remove it from the output + device (Optional[torch.device]): device to use for evaluation + (defaults to device of first model parameter) + """ + if target_dictionary is None: + target_dictionary = source_dictionary + if device is None: + device = next(models[0].parameters()).device + + gen_timer = StopwatchMeter() + scorer = SequenceScorer(target_dictionary, softmax_batch) + + score_sum = 0.0 + count = 0 + + if post_process is not None: + if post_process in {"subword_nmt", "@@ "}: + bpe_cont = post_process.rstrip() + bpe_toks = { + i + for i in range(len(source_dictionary)) + if source_dictionary[i].endswith(bpe_cont) + } + else: + raise NotImplementedError( + f"--post-process={post_process} is not implemented" + ) + bpe_len = len(bpe_cont) + else: + bpe_toks = None + bpe_len = 0 + + word_stats = dict() + + for sample in batch_iterator: + if "net_input" not in sample: + continue + + sample = utils.move_to_cuda(sample, device=device) + + gen_timer.start() + hypos = scorer.generate(models, sample) + gen_timer.stop(sample["ntokens"]) + + for i, hypos_i in enumerate(hypos): + hypo = hypos_i[0] + sample_id = sample["id"][i] + + tokens = hypo["tokens"] + tgt_len = tokens.numel() + pos_scores = hypo["positional_scores"].float() + + if remove_bos_token: + assert hypo["tokens"][0].item() == target_dictionary.bos() + tokens = tokens[1:] + pos_scores = pos_scores[1:] + + skipped_toks = 0 + if bpe_toks is not None: + for i in range(tgt_len - 1): + if tokens[i].item() in bpe_toks: + skipped_toks += 1 + pos_scores[i + 1] += pos_scores[i] + pos_scores[i] = 0 + + inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf")) + if inf_scores.any(): + # use lazy %s formatting; passing the string as a bare extra + # argument would never be rendered by the logger + logger.info( + "skipping tokens with inf scores: %s", + target_dictionary.string(tokens[inf_scores.nonzero()]), + ) + pos_scores = pos_scores[(~inf_scores).nonzero()] + score_sum += pos_scores.sum().cpu() + count += pos_scores.numel() - skipped_toks + + if output_word_probs or output_word_stats: + w = "" + word_prob = [] + is_bpe = False + for i in range(len(tokens)): + w_ind = tokens[i].item() + w += source_dictionary[w_ind] + if bpe_toks is not None and w_ind in bpe_toks: + w = w[:-bpe_len] + is_bpe = True + else: + word_prob.append((w, pos_scores[i].item())) + + next_prob = None + ind = i + 1 + while ind < len(tokens): + if pos_scores[ind].item() != 0: + next_prob = pos_scores[ind] + break + ind += 1 + + word_stats.setdefault(w, WordStat(w, is_bpe)).add( + pos_scores[i].item(), next_prob + ) + is_bpe = False + w = "" + if output_word_probs: + logger.info( + str(int(sample_id)) + + " " + + ( + "\t".join( + "{} [{:.2f}]".format(x[0], x[1]) for x in word_prob + ) + ) + ) + + avg_nll_loss = ( + -score_sum / count / math.log(2) if count > 0 else 0 + ) # convert to base 2 + logger.info( + "Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format( + gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 + ) + ) + + if 
output_word_stats: + for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): + logger.info(ws) + + return { + "loss": avg_nll_loss, + "perplexity": 2**avg_nll_loss, + } + + +class WordStat(object): + def __init__(self, word, is_bpe): + self.word = word + self.is_bpe = is_bpe + self.log_prob = 0 + self.next_word_prob = 0 + self.count = 0 + self.missing_next_words = 0 + + def add(self, log_prob, next_word_prob): + """Increments counters for the sum of log probs of the current word and the next + word (given context ending at the current word). Since the next word might be at the end of the example, + or it might not be counted because it is not an ending subword unit, + this also keeps track of how many of those we have seen.""" + if next_word_prob is not None: + self.next_word_prob += next_word_prob + else: + self.missing_next_words += 1 + self.log_prob += log_prob + self.count += 1 + + def __str__(self): + return "{}\t{}\t{}\t{}\t{}\t{}".format( + self.word, + self.count, + self.log_prob, + self.is_bpe, + self.next_word_prob, + self.count - self.missing_next_words, + ) + + +def main(cfg: DictConfig, **unused_kwargs): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + + logger.info(cfg) + + if cfg.eval_lm.context_window > 0: + # reduce tokens per sample by the required context window size + cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + + # Initialize the task using the current *cfg* + task = tasks.setup_task(cfg.task) + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=eval(cfg.common_eval.model_overrides), + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + task=task, + ) + + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) + + # Optimize ensemble for generation and set the source and dest dicts on the model + # (required by scorer) + for model in models: + if use_fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + assert len(models) > 0 + + logger.info( + "num. 
model params: {:,}".format(sum(p.numel() for p in models[0].parameters())) + ) + + # Load dataset splits + task.load_dataset(cfg.dataset.gen_subset) + dataset = task.dataset(cfg.dataset.gen_subset) + logger.info( + "{} {} {:,} examples".format( + cfg.task.data, cfg.dataset.gen_subset, len(dataset) + ) + ) + + itr = task.eval_lm_dataloader( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens or 36000, + batch_size=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + *[model.max_positions() for model in models] + ), + num_shards=max( + cfg.dataset.num_shards, + cfg.distributed_training.distributed_world_size, + ), + shard_id=max( + cfg.dataset.shard_id, + cfg.distributed_training.distributed_rank, + ), + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + context_window=cfg.eval_lm.context_window, + ) + + itr = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + results = eval_lm( + models=models, + source_dictionary=task.source_dictionary, + batch_iterator=itr, + post_process=cfg.common_eval.post_process, + output_word_probs=cfg.eval_lm.output_word_probs, + output_word_stats=cfg.eval_lm.output_word_stats, + target_dictionary=task.target_dictionary, + softmax_batch=cfg.eval_lm.softmax_batch, + remove_bos_token=getattr(cfg.task, "add_bos_token", False), + ) + + logger.info( + "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( + results["loss"], results["perplexity"] + ) + ) + + return results + + +def cli_main(): + parser = options.get_eval_lm_parser() + args = options.parse_args_and_arch(parser) + + distributed_utils.call_main(convert_namespace_to_omegaconf(args), main) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/fairseq_cli/generate.py b/fairseq/fairseq_cli/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..b8757835d406bf8bc3782db54b63b044c34ed184 --- /dev/null +++ b/fairseq/fairseq_cli/generate.py @@ -0,0 +1,417 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Translate pre-processed data with a trained model. +""" + +import ast +import logging +import math +import os +import sys +from argparse import Namespace +from itertools import chain + +import numpy as np +import torch +from omegaconf import DictConfig + +from fairseq import checkpoint_utils, options, scoring, tasks, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter, TimeMeter + + +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required for generation!" 
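+ # The two checks below validate option combinations before any heavy setup: + # sampling draws hypotheses instead of scoring a fixed beam, so the number + # of requested hypotheses (--nbest) must match the sample size (--beam), + # and unknown-word replacement (--replace-unk) needs the original raw + # text, which only --dataset-impl=raw preserves.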
+ assert ( + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" + ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" + + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) + output_path = os.path.join( + cfg.common_eval.results_path, + "generate-{}.txt".format(cfg.dataset.gen_subset), + ) + with open(output_path, "w", buffering=1, encoding="utf-8") as h: + return _main(cfg, h) + else: + return _main(cfg, sys.stdout) + + +def get_symbols_to_strip_from_output(generator): + if hasattr(generator, "symbols_to_strip_from_output"): + return generator.symbols_to_strip_from_output + else: + return {generator.eos} + + +def _main(cfg: DictConfig, output_file): + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=output_file, + ) + logger = logging.getLogger("fairseq_cli.generate") + + utils.import_user_module(cfg.common) + + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 12000 + logger.info(cfg) + + # Fix seed for stochastic decoding + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + # Load dataset splits + task = tasks.setup_task(cfg.task) + + # Set dictionaries + try: + src_dict = getattr(task, "source_dictionary", None) + except NotImplementedError: + src_dict = None + tgt_dict = task.target_dictionary + + overrides = ast.literal_eval(cfg.common_eval.model_overrides) + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, saved_cfg = checkpoint_utils.load_model_ensemble( + utils.split_paths(cfg.common_eval.path), + arg_overrides=overrides, + task=task, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + ) + + # Loading the dataset must happen after the checkpoint has been loaded, so we can give it the saved task config + task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) + + if cfg.generation.lm_path is not None: + overrides["data"] = cfg.task.data + + try: + lms, _ = checkpoint_utils.load_model_ensemble( + [cfg.generation.lm_path], arg_overrides=overrides, task=None + ) + except Exception: + logger.warning( + f"Failed to load language model! 
Please make sure that the language model dict is the same " + f"as target dict and is located in the data dir ({cfg.task.data})" + ) + raise + + assert len(lms) == 1 + else: + lms = [None] + + # Optimize ensemble for generation + for model in chain(models, lms): + if model is None: + continue + if cfg.common.fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + task.max_positions(), *[m.max_positions() for m in models] + ), + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + # Initialize generator + gen_timer = StopwatchMeter() + + extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight} + generator = task.build_generator( + models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + + # Handle tokenization and BPE + tokenizer = task.build_tokenizer(cfg.tokenizer) + bpe = task.build_bpe(cfg.bpe) + + def decode_fn(x): + if bpe is not None: + x = bpe.decode(x) + if tokenizer is not None: + x = tokenizer.decode(x) + return x + + scorer = scoring.build_scorer(cfg.scoring, tgt_dict) + + num_sentences = 0 + has_target = True + wps_meter = TimeMeter() + for sample in progress: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if "net_input" not in sample: + continue + + prefix_tokens = None + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] + + constraints = None + if "constraints" in sample: + constraints = sample["constraints"] + + gen_timer.start() + hypos = task.inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i, sample_id in enumerate(sample["id"].tolist()): + has_target = sample["target"] is not None + + # Remove padding + if "src_tokens" in sample["net_input"]: + src_tokens = utils.strip_pad( + sample["net_input"]["src_tokens"][i, :], tgt_dict.pad() + ) + else: + src_tokens = None + + target_tokens = None + if has_target: + target_tokens = ( + utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu() + ) + + # Either retrieve the original sentences or regenerate them from tokens. 
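The branch just below either re-reads the original raw text (only possible with `--dataset-impl=raw`, which is why `--replace-unk` requires it) or reconstructs the strings from token IDs. For intuition, here is a simplified, self-contained sketch of what alignment-dictionary-based `<unk>` replacement amounts to later in `utils.post_process_prediction`; it is an illustration, not the fairseq implementation:

```python
def replace_unk(hypo_words, src_words, alignment, align_dict, unk="<unk>"):
    # For every unknown hypothesis word, back off to the source word it is
    # aligned to, optionally translating it through a lexical dictionary.
    out = []
    for i, word in enumerate(hypo_words):
        if word == unk and i < len(alignment):
            src_word = src_words[alignment[i]]
            word = align_dict.get(src_word, src_word)  # default: copy the source word
        out.append(word)
    return out


print(replace_unk(["ein", "<unk>", "!"], ["a", "dog", "!"], [0, 1, 2], {"dog": "Hund"}))
# ['ein', 'Hund', '!']
```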
+ if align_dict is not None: + src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text( + sample_id + ) + target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text( + sample_id + ) + else: + if src_dict is not None: + src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) + else: + src_str = "" + if has_target: + target_str = tgt_dict.string( + target_tokens, + cfg.common_eval.post_process, + escape_unk=True, + extra_symbols_to_ignore=get_symbols_to_strip_from_output( + generator + ), + ) + + src_str = decode_fn(src_str) + if has_target: + target_str = decode_fn(target_str) + + if not cfg.common_eval.quiet: + if src_dict is not None: + print("S-{}\t{}".format(sample_id, src_str), file=output_file) + if has_target: + print("T-{}\t{}".format(sample_id, target_str), file=output_file) + + # Process top predictions + for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): + hypo_tokens, hypo_str, alignment = utils.post_process_prediction( + hypo_tokens=hypo["tokens"].int().cpu(), + src_str=src_str, + alignment=hypo["alignment"], + align_dict=align_dict, + tgt_dict=tgt_dict, + remove_bpe=cfg.common_eval.post_process, + extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), + ) + detok_hypo_str = decode_fn(hypo_str) + if not cfg.common_eval.quiet: + score = hypo["score"] / math.log(2) # convert to base 2 + # original hypothesis (after tokenization and BPE) + print( + "H-{}\t{}\t{}".format(sample_id, score, hypo_str), + file=output_file, + ) + # detokenized hypothesis + print( + "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str), + file=output_file, + ) + print( + "P-{}\t{}".format( + sample_id, + " ".join( + map( + lambda x: "{:.4f}".format(x), + # convert from base e to base 2 + hypo["positional_scores"] + .div_(math.log(2)) + .tolist(), + ) + ), + ), + file=output_file, + ) + + if cfg.generation.print_alignment == "hard": + print( + "A-{}\t{}".format( + sample_id, + " ".join( + [ + "{}-{}".format(src_idx, tgt_idx) + for src_idx, tgt_idx in alignment + ] + ), + ), + file=output_file, + ) + if cfg.generation.print_alignment == "soft": + print( + "A-{}\t{}".format( + sample_id, + " ".join( + [",".join(src_probs) for src_probs in alignment] + ), + ), + file=output_file, + ) + + if cfg.generation.print_step: + print( + "I-{}\t{}".format(sample_id, hypo["steps"]), + file=output_file, + ) + + if cfg.generation.retain_iter_history: + for step, h in enumerate(hypo["history"]): + _, h_str, _ = utils.post_process_prediction( + hypo_tokens=h["tokens"].int().cpu(), + src_str=src_str, + alignment=None, + align_dict=None, + tgt_dict=tgt_dict, + remove_bpe=None, + ) + print( + "E-{}_{}\t{}".format(sample_id, step, h_str), + file=output_file, + ) + + # Score only the top hypothesis + if has_target and j == 0: + if ( + align_dict is not None + or cfg.common_eval.post_process is not None + ): + # Convert back to tokens for evaluation with unk replacement and/or without BPE + target_tokens = tgt_dict.encode_line( + target_str, add_if_not_exist=True + ) + hypo_tokens = tgt_dict.encode_line( + detok_hypo_str, add_if_not_exist=True + ) + if hasattr(scorer, "add_string"): + scorer.add_string(target_str, detok_hypo_str) + else: + scorer.add(target_tokens, hypo_tokens) + + wps_meter.update(num_generated_tokens) + progress.log({"wps": round(wps_meter.avg)}) + num_sentences += ( + sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + ) + + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info( + "Translated 
{:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
+            num_sentences,
+            gen_timer.n,
+            gen_timer.sum,
+            num_sentences / gen_timer.sum,
+            1.0 / gen_timer.avg,
+        )
+    )
+    if has_target:
+        if cfg.bpe and not cfg.generation.sacrebleu:
+            if cfg.common_eval.post_process:
+                logger.warning(
+                    "BLEU score is being computed by splitting detokenized string on spaces; this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
+                )
+            else:
+                logger.warning(
+                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
+                )
+        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
+        print(
+            "Generate {} with beam={}: {}".format(
+                cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
+            ),
+            file=output_file,
+        )
+
+    return scorer
+
+
+def cli_main():
+    parser = options.get_generation_parser()
+    # TODO: replace this workaround with refactoring of `AudioPretraining`
+    parser.add_argument(
+        "--arch",
+        "-a",
+        metavar="ARCH",
+        default="wav2vec2",
+        help="Model architecture. For constructing tasks that rely on "
+        "model args (e.g. `AudioPretraining`)",
+    )
+    args = options.parse_args_and_arch(parser)
+    main(args)
+
+
+if __name__ == "__main__":
+    cli_main()
diff --git a/fairseq/fairseq_cli/hydra_train.py b/fairseq/fairseq_cli/hydra_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..607340af0df3326b9245a026366c7d6d004f0013
--- /dev/null
+++ b/fairseq/fairseq_cli/hydra_train.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+
+import hydra
+import torch
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import OmegaConf, open_dict
+
+from fairseq import distributed_utils, metrics
+from fairseq.dataclass.configs import FairseqConfig
+from fairseq.dataclass.initialize import add_defaults, hydra_init
+from fairseq.dataclass.utils import omegaconf_no_object_check
+from fairseq.utils import reset_logging
+from fairseq_cli.train import main as pre_main
+
+logger = logging.getLogger("fairseq_cli.hydra_train")
+
+
+@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
+def hydra_main(cfg: FairseqConfig) -> float:
+    return _hydra_main(cfg)
+
+
+def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
+    add_defaults(cfg)
+
+    if cfg.common.reset_logging:
+        reset_logging()  # Hydra hijacks logging, fix that
+    else:
+        # check if directly called or called through hydra_main
+        if HydraConfig.initialized():
+            with open_dict(cfg):
+                # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
+                cfg.job_logging_cfg = OmegaConf.to_container(
+                    HydraConfig.get().job_logging, resolve=True
+                )
+
+    with omegaconf_no_object_check():
+        cfg = OmegaConf.create(
+            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
+        )
+    OmegaConf.set_struct(cfg, True)
+
+    try:
+        if cfg.common.profile:
+            with torch.cuda.profiler.profile():
+                with torch.autograd.profiler.emit_nvtx():
+                    distributed_utils.call_main(cfg, pre_main, **kwargs)
+        else:
+            distributed_utils.call_main(cfg, pre_main, **kwargs)
+    except BaseException as e:
+        if not cfg.common.suppress_crashes:
+            raise
+        else:
+            logger.error("Crashed!
" + str(e)) + + # get best val and return - useful for sweepers + try: + best_val = metrics.get_smoothed_value( + "valid", cfg.checkpoint.best_checkpoint_metric + ) + except: + best_val = None + + if best_val is None: + best_val = float("inf") + + return best_val + + +def cli_main(): + try: + from hydra._internal.utils import get_args + + cfg_name = get_args().config_name or "config" + except: + logger.warning("Failed to get config name from hydra args") + cfg_name = "config" + + hydra_init(cfg_name) + hydra_main() + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/fairseq_cli/hydra_validate.py b/fairseq/fairseq_cli/hydra_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6f7612d0d7d5c2581bdb8bd7ebfcda22c3a965 --- /dev/null +++ b/fairseq/fairseq_cli/hydra_validate.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys +from itertools import chain + +import torch +from hydra.core.hydra_config import HydraConfig +from omegaconf import OmegaConf, open_dict +import hydra + +from fairseq import checkpoint_utils, distributed_utils, utils +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.initialize import add_defaults, hydra_init +from fairseq.dataclass.utils import omegaconf_no_object_check +from fairseq.distributed import utils as distributed_utils +from fairseq.logging import metrics, progress_bar +from fairseq.utils import reset_logging + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.validate") + + +@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") +def hydra_main(cfg: FairseqConfig) -> float: + return _hydra_main(cfg) + + +def _hydra_main(cfg: FairseqConfig, **kwargs) -> float: + add_defaults(cfg) + + if cfg.common.reset_logging: + reset_logging() # Hydra hijacks logging, fix that + else: + # check if directly called or called through hydra_main + if HydraConfig.initialized(): + with open_dict(cfg): + # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) + cfg.job_logging_cfg = OmegaConf.to_container( + HydraConfig.get().job_logging, resolve=True + ) + + with omegaconf_no_object_check(): + cfg = OmegaConf.create( + OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + ) + OmegaConf.set_struct(cfg, True) + + assert ( + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" + + distributed_utils.call_main(cfg, validate, **kwargs) + + +def validate(cfg): + utils.import_user_module(cfg.common) + + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) + + if cfg.distributed_training.distributed_world_size > 1: + data_parallel_world_size = distributed_utils.get_data_parallel_world_size() + data_parallel_rank = distributed_utils.get_data_parallel_rank() + else: + data_parallel_world_size = 1 + data_parallel_rank = 0 + + overrides = {"task": {"data": cfg.task.data}} + + # Load ensemble + logger.info("loading model(s) from 
{}".format(cfg.common_eval.path)) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=overrides, + suffix=cfg.checkpoint.checkpoint_suffix, + ) + model = models[0] + + # Move models to GPU + for model in models: + model.eval() + if use_fp16: + model.half() + if use_cuda: + model.cuda() + + # Print args + logger.info(saved_cfg) + + # Build criterion + criterion = task.build_criterion(saved_cfg.criterion, from_checkpoint=True) + criterion.eval() + + for subset in cfg.dataset.valid_subset.split(","): + try: + task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task) + dataset = task.dataset(subset) + except KeyError: + raise Exception("Cannot find dataset: " + subset) + + # Initialize data iterator + itr = task.get_batch_iterator( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[m.max_positions() for m in models], + ), + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=data_parallel_world_size, + shard_id=data_parallel_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + prefix=f"valid on '{subset}' subset", + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + def apply_half(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.half) + return t + + log_outputs = [] + for i, sample in enumerate(progress): + sample = utils.move_to_cuda(sample) if use_cuda else sample + + if use_fp16: + sample = utils.apply_to_sample(apply_half, sample) + + _loss, _sample_size, log_output = task.valid_step(sample, model, criterion) + with metrics.aggregate() as agg: + task.reduce_metrics([log_output], criterion) + progress.log(agg.get_smoothed_values(), step=i) + # progress.log(log_output, step=i) from vision + log_outputs.append(log_output) + + if data_parallel_world_size > 1: + log_outputs = distributed_utils.all_gather_list( + log_outputs, + max_size=cfg.common.all_gather_list_size, + group=distributed_utils.get_data_parallel_group(), + ) + log_outputs = list(chain.from_iterable(log_outputs)) + + with metrics.aggregate() as agg: + task.reduce_metrics(log_outputs, criterion) + log_output = agg.get_smoothed_values() + + progress.print(log_output, tag=subset, step=i) + + +def cli_main(): + try: + from hydra._internal.utils import get_args + + cfg_name = get_args().config_name or "config" + except: + logger.warning("Failed to get config name from hydra args") + cfg_name = "config" + + hydra_init(cfg_name) + hydra_main() + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/fairseq_cli/interactive.py b/fairseq/fairseq_cli/interactive.py new file mode 100644 index 0000000000000000000000000000000000000000..03265d00e81052bbe2296fc07bb0531994a63d7c --- /dev/null +++ b/fairseq/fairseq_cli/interactive.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Translate raw text with a trained model. Batches data on-the-fly. 
+""" + +import ast +import fileinput +import logging +import math +import os +import sys +import time +from argparse import Namespace +from collections import namedtuple + +import numpy as np +import torch + +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.token_generation_constraints import pack_constraints, unpack_constraints +from fairseq_cli.generate import get_symbols_to_strip_from_output + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.interactive") + + +Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints") +Translation = namedtuple("Translation", "src_str hypos pos_scores alignments") + + +def buffered_read(input, buffer_size): + buffer = [] + with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h: + for src_str in h: + buffer.append(src_str.strip()) + if len(buffer) >= buffer_size: + yield buffer + buffer = [] + + if len(buffer) > 0: + yield buffer + + +def make_batches(lines, cfg, task, max_positions, encode_fn): + def encode_fn_target(x): + return encode_fn(x) + + if cfg.generation.constraints: + # Strip (tab-delimited) contraints, if present, from input lines, + # store them in batch_constraints + batch_constraints = [list() for _ in lines] + for i, line in enumerate(lines): + if "\t" in line: + lines[i], *batch_constraints[i] = line.split("\t") + + # Convert each List[str] to List[Tensor] + for i, constraint_list in enumerate(batch_constraints): + batch_constraints[i] = [ + task.target_dictionary.encode_line( + encode_fn_target(constraint), + append_eos=False, + add_if_not_exist=False, + ) + for constraint in constraint_list + ] + + if cfg.generation.constraints: + constraints_tensor = pack_constraints(batch_constraints) + else: + constraints_tensor = None + + tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn) + + itr = task.get_batch_iterator( + dataset=task.build_dataset_for_inference( + tokens, lengths, constraints=constraints_tensor + ), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=max_positions, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + ).next_epoch_itr(shuffle=False) + for batch in itr: + ids = batch["id"] + src_tokens = batch["net_input"]["src_tokens"] + src_lengths = batch["net_input"]["src_lengths"] + constraints = batch.get("constraints", None) + + yield Batch( + ids=ids, + src_tokens=src_tokens, + src_lengths=src_lengths, + constraints=constraints, + ) + + +def main(cfg: FairseqConfig): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + start_time = time.time() + total_translate_time = 0 + + utils.import_user_module(cfg.common) + + if cfg.interactive.buffer_size < 1: + cfg.interactive.buffer_size = 1 + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.batch_size = 1 + + assert ( + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + not cfg.dataset.batch_size + or cfg.dataset.batch_size <= cfg.interactive.buffer_size + ), "--batch-size cannot be larger than --buffer-size" + + logger.info(cfg) + + # Fix seed for 
stochastic decoding + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + # Setup task, e.g., translation + task = tasks.setup_task(cfg.task) + + # Load ensemble + overrides = ast.literal_eval(cfg.common_eval.model_overrides) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + utils.split_paths(cfg.common_eval.path), + arg_overrides=overrides, + task=task, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + ) + + # Set dictionaries + src_dict = task.source_dictionary + tgt_dict = task.target_dictionary + + # Optimize ensemble for generation + for model in models: + if model is None: + continue + if cfg.common.fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + # Initialize generator + generator = task.build_generator(models, cfg.generation) + + # Handle tokenization and BPE + tokenizer = task.build_tokenizer(cfg.tokenizer) + bpe = task.build_bpe(cfg.bpe) + + def encode_fn(x): + if tokenizer is not None: + x = tokenizer.encode(x) + if bpe is not None: + x = bpe.encode(x) + return x + + def decode_fn(x): + if bpe is not None: + x = bpe.decode(x) + if tokenizer is not None: + x = tokenizer.decode(x) + return x + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) + + max_positions = utils.resolve_max_positions( + task.max_positions(), *[model.max_positions() for model in models] + ) + + if cfg.generation.constraints: + logger.warning( + "NOTE: Constrained decoding currently assumes a shared subword vocabulary." 
+ ) + + if cfg.interactive.buffer_size > 1: + logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size) + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info("Type the input sentence and press return:") + start_id = 0 + for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size): + results = [] + for batch in make_batches(inputs, cfg, task, max_positions, encode_fn): + bsz = batch.src_tokens.size(0) + src_tokens = batch.src_tokens + src_lengths = batch.src_lengths + constraints = batch.constraints + if use_cuda: + src_tokens = src_tokens.cuda() + src_lengths = src_lengths.cuda() + if constraints is not None: + constraints = constraints.cuda() + + sample = { + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + } + translate_start_time = time.time() + translations = task.inference_step( + generator, models, sample, constraints=constraints + ) + translate_time = time.time() - translate_start_time + total_translate_time += translate_time + list_constraints = [[] for _ in range(bsz)] + if cfg.generation.constraints: + list_constraints = [unpack_constraints(c) for c in constraints] + for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): + src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) + constraints = list_constraints[i] + results.append( + ( + start_id + id, + src_tokens_i, + hypos, + { + "constraints": constraints, + "time": translate_time / len(translations), + }, + ) + ) + + # sort output to match input order + for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): + src_str = "" + if src_dict is not None: + src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) + print("S-{}\t{}".format(id_, src_str)) + print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) + for constraint in info["constraints"]: + print( + "C-{}\t{}".format( + id_, + tgt_dict.string(constraint, cfg.common_eval.post_process), + ) + ) + + # Process top predictions + for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]: + hypo_tokens, hypo_str, alignment = utils.post_process_prediction( + hypo_tokens=hypo["tokens"].int().cpu(), + src_str=src_str, + alignment=hypo["alignment"], + align_dict=align_dict, + tgt_dict=tgt_dict, + remove_bpe=cfg.common_eval.post_process, + extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), + ) + detok_hypo_str = decode_fn(hypo_str) + score = hypo["score"] / math.log(2) # convert to base 2 + # original hypothesis (after tokenization and BPE) + print("H-{}\t{}\t{}".format(id_, score, hypo_str)) + # detokenized hypothesis + print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str)) + print( + "P-{}\t{}".format( + id_, + " ".join( + map( + lambda x: "{:.4f}".format(x), + # convert from base e to base 2 + hypo["positional_scores"].div_(math.log(2)).tolist(), + ) + ), + ) + ) + if cfg.generation.print_alignment: + alignment_str = " ".join( + ["{}-{}".format(src, tgt) for src, tgt in alignment] + ) + print("A-{}\t{}".format(id_, alignment_str)) + + # update running id_ counter + start_id += len(inputs) + + logger.info( + "Total time: {:.3f} seconds; translation time: {:.3f}".format( + time.time() - start_time, total_translate_time + ) + ) + + +def cli_main(): + parser = options.get_interactive_generation_parser() + args = options.parse_args_and_arch(parser) + distributed_utils.call_main(convert_namespace_to_omegaconf(args), main) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/fairseq_cli/preprocess.py 
b/fairseq/fairseq_cli/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba9e09338abd3b04b5d2497ce3f97613802bd6f --- /dev/null +++ b/fairseq/fairseq_cli/preprocess.py @@ -0,0 +1,393 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Data pre-processing: build vocabularies and binarize training data. +""" + +import logging +import os +import shutil +import sys +import typing as tp +from argparse import Namespace +from itertools import zip_longest + +from fairseq import options, tasks, utils +from fairseq.binarizer import ( + AlignmentDatasetBinarizer, + FileBinarizer, + VocabularyDatasetBinarizer, +) +from fairseq.data import Dictionary + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.preprocess") + +##################################################################### +# file name tools +##################################################################### + + +def _train_path(lang, trainpref): + return "{}{}".format(trainpref, ("." + lang) if lang else "") + + +def _file_name(prefix, lang): + fname = prefix + if lang is not None: + fname += ".{lang}".format(lang=lang) + return fname + + +def _dest_path(prefix, lang, destdir): + return os.path.join(destdir, _file_name(prefix, lang)) + + +def _dict_path(lang, destdir): + return _dest_path("dict", lang, destdir) + ".txt" + + +def dataset_dest_prefix(args, output_prefix, lang): + base = os.path.join(args.destdir, output_prefix) + if lang is not None: + lang_part = f".{args.source_lang}-{args.target_lang}.{lang}" + elif args.only_source: + lang_part = "" + else: + lang_part = f".{args.source_lang}-{args.target_lang}" + + return "{}{}".format(base, lang_part) + + +def dataset_dest_file(args, output_prefix, lang, extension): + return "{}.{}".format(dataset_dest_prefix(args, output_prefix, lang), extension) + + +##################################################################### +# dictionary tools +##################################################################### + + +def _build_dictionary( + filenames, + task, + args, + src=False, + tgt=False, +): + assert src ^ tgt + return task.build_dictionary( + filenames, + workers=args.workers, + threshold=args.thresholdsrc if src else args.thresholdtgt, + nwords=args.nwordssrc if src else args.nwordstgt, + padding_factor=args.padding_factor, + ) + + +##################################################################### +# bin file creation logic +##################################################################### + + +def _make_binary_dataset( + vocab: Dictionary, + input_prefix: str, + output_prefix: str, + lang: tp.Optional[str], + num_workers: int, + args: Namespace, +): + logger.info("[{}] Dictionary: {} types".format(lang, len(vocab))) + + binarizer = VocabularyDatasetBinarizer( + vocab, + append_eos=True, + ) + + input_file = "{}{}".format(input_prefix, ("." 
+ lang) if lang is not None else "") + full_output_prefix = dataset_dest_prefix(args, output_prefix, lang) + + final_summary = FileBinarizer.multiprocess_dataset( + input_file, + args.dataset_impl, + binarizer, + full_output_prefix, + vocab_size=len(vocab), + num_workers=num_workers, + ) + + logger.info(f"[{lang}] {input_file}: {final_summary} (by {vocab.unk_word})") + + +def _make_binary_alignment_dataset( + input_prefix: str, output_prefix: str, num_workers: int, args: Namespace +): + + binarizer = AlignmentDatasetBinarizer(utils.parse_alignment) + + input_file = input_prefix + full_output_prefix = dataset_dest_prefix(args, output_prefix, lang=None) + + final_summary = FileBinarizer.multiprocess_dataset( + input_file, + args.dataset_impl, + binarizer, + full_output_prefix, + vocab_size=None, + num_workers=num_workers, + ) + + logger.info( + "[alignments] {}: parsed {} alignments".format( + input_file, final_summary.num_seq + ) + ) + + +##################################################################### +# routing logic +##################################################################### + + +def _make_dataset( + vocab: Dictionary, + input_prefix: str, + output_prefix: str, + lang: tp.Optional[str], + args: Namespace, + num_workers: int, +): + if args.dataset_impl == "raw": + # Copy original text file to destination folder + output_text_file = _dest_path( + output_prefix + ".{}-{}".format(args.source_lang, args.target_lang), + lang, + args.destdir, + ) + shutil.copyfile(_file_name(input_prefix, lang), output_text_file) + else: + _make_binary_dataset( + vocab, input_prefix, output_prefix, lang, num_workers, args + ) + + +def _make_all(lang, vocab, args): + if args.trainpref: + _make_dataset( + vocab, args.trainpref, "train", lang, args=args, num_workers=args.workers + ) + if args.validpref: + for k, validpref in enumerate(args.validpref.split(",")): + outprefix = "valid{}".format(k) if k > 0 else "valid" + _make_dataset( + vocab, validpref, outprefix, lang, args=args, num_workers=args.workers + ) + if args.testpref: + for k, testpref in enumerate(args.testpref.split(",")): + outprefix = "test{}".format(k) if k > 0 else "test" + _make_dataset( + vocab, testpref, outprefix, lang, args=args, num_workers=args.workers + ) + + +def _make_all_alignments(args): + if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix): + _make_binary_alignment_dataset( + args.trainpref + "." + args.align_suffix, + "train.align", + num_workers=args.workers, + args=args, + ) + if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix): + _make_binary_alignment_dataset( + args.validpref + "." + args.align_suffix, + "valid.align", + num_workers=args.workers, + args=args, + ) + if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix): + _make_binary_alignment_dataset( + args.testpref + "." 
+ args.align_suffix, + "test.align", + num_workers=args.workers, + args=args, + ) + + +##################################################################### +# align +##################################################################### + + +def _align_files(args, src_dict, tgt_dict): + assert args.trainpref, "--trainpref must be set if --alignfile is specified" + src_file_name = _train_path(args.source_lang, args.trainpref) + tgt_file_name = _train_path(args.target_lang, args.trainpref) + freq_map = {} + with open(args.alignfile, "r", encoding="utf-8") as align_file: + with open(src_file_name, "r", encoding="utf-8") as src_file: + with open(tgt_file_name, "r", encoding="utf-8") as tgt_file: + for a, s, t in zip_longest(align_file, src_file, tgt_file): + si = src_dict.encode_line(s, add_if_not_exist=False) + ti = tgt_dict.encode_line(t, add_if_not_exist=False) + ai = list(map(lambda x: tuple(x.split("-")), a.split())) + for sai, tai in ai: + srcidx = si[int(sai)] + tgtidx = ti[int(tai)] + if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk(): + assert srcidx != src_dict.pad() + assert srcidx != src_dict.eos() + assert tgtidx != tgt_dict.pad() + assert tgtidx != tgt_dict.eos() + if srcidx not in freq_map: + freq_map[srcidx] = {} + if tgtidx not in freq_map[srcidx]: + freq_map[srcidx][tgtidx] = 1 + else: + freq_map[srcidx][tgtidx] += 1 + align_dict = {} + for srcidx in freq_map.keys(): + align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get) + with open( + os.path.join( + args.destdir, + "alignment.{}-{}.txt".format(args.source_lang, args.target_lang), + ), + "w", + encoding="utf-8", + ) as f: + for k, v in align_dict.items(): + print("{} {}".format(src_dict[k], tgt_dict[v]), file=f) + + +##################################################################### +# MAIN +##################################################################### + + +def main(args): + # setup some basic things + utils.import_user_module(args) + + os.makedirs(args.destdir, exist_ok=True) + + logger.addHandler( + logging.FileHandler( + filename=os.path.join(args.destdir, "preprocess.log"), + ) + ) + logger.info(args) + + assert ( + args.dataset_impl != "huffman" + ), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly." 
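With the helper functions in place, `main` below builds or loads the dictionaries and then binarizes every split. As a worked example of the file layout those naming helpers produce (assuming the default indexed dataset implementation, which typically writes a `.bin`/`.idx` pair per split and language):

```python
import os

# Hypothetical invocation: --source-lang de --target-lang en --destdir data-bin
destdir, src, tgt = "data-bin", "de", "en"

# one dictionary per language (cf. _dict_path above)
print(os.path.join(destdir, "dict.{}.txt".format(src)))  # data-bin/dict.de.txt
print(os.path.join(destdir, "dict.{}.txt".format(tgt)))  # data-bin/dict.en.txt

# one binarized dataset per split and language (cf. dataset_dest_prefix above)
for split in ("train", "valid", "test"):
    for lang in (src, tgt):
        prefix = os.path.join(destdir, "{}.{}-{}.{}".format(split, src, tgt, lang))
        print(prefix + ".bin", prefix + ".idx")  # e.g. data-bin/train.de-en.de.bin
```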
+ + # build dictionaries + + target = not args.only_source + + if not args.srcdict and os.path.exists(_dict_path(args.source_lang, args.destdir)): + raise FileExistsError(_dict_path(args.source_lang, args.destdir)) + + if ( + target + and not args.tgtdict + and os.path.exists(_dict_path(args.target_lang, args.destdir)) + ): + raise FileExistsError(_dict_path(args.target_lang, args.destdir)) + + task = tasks.get_task(args.task) + + if args.joined_dictionary: + assert ( + not args.srcdict or not args.tgtdict + ), "cannot use both --srcdict and --tgtdict with --joined-dictionary" + + if args.srcdict: + src_dict = task.load_dictionary(args.srcdict) + elif args.tgtdict: + src_dict = task.load_dictionary(args.tgtdict) + else: + assert ( + args.trainpref + ), "--trainpref must be set if --srcdict is not specified" + src_dict = _build_dictionary( + { + _train_path(lang, args.trainpref) + for lang in [args.source_lang, args.target_lang] + }, + task=task, + args=args, + src=True, + ) + tgt_dict = src_dict + else: + if args.srcdict: + src_dict = task.load_dictionary(args.srcdict) + else: + assert ( + args.trainpref + ), "--trainpref must be set if --srcdict is not specified" + src_dict = _build_dictionary( + [_train_path(args.source_lang, args.trainpref)], + task=task, + args=args, + src=True, + ) + + if target: + if args.tgtdict: + tgt_dict = task.load_dictionary(args.tgtdict) + else: + assert ( + args.trainpref + ), "--trainpref must be set if --tgtdict is not specified" + tgt_dict = _build_dictionary( + [_train_path(args.target_lang, args.trainpref)], + task=task, + args=args, + tgt=True, + ) + else: + tgt_dict = None + + # save dictionaries + + src_dict.save(_dict_path(args.source_lang, args.destdir)) + if target and tgt_dict is not None: + tgt_dict.save(_dict_path(args.target_lang, args.destdir)) + + if args.dict_only: + return + + _make_all(args.source_lang, src_dict, args) + if target: + _make_all(args.target_lang, tgt_dict, args) + + # align the datasets if needed + if args.align_suffix: + _make_all_alignments(args) + + logger.info("Wrote preprocessed data to {}".format(args.destdir)) + + if args.alignfile: + _align_files(args, src_dict=src_dict, tgt_dict=tgt_dict) + + +def cli_main(): + parser = options.get_preprocessing_parser() + args = parser.parse_args() + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/fairseq_cli/score.py b/fairseq/fairseq_cli/score.py new file mode 100644 index 0000000000000000000000000000000000000000..0b207be959d55f6a56d8c5eb7db3dbe0c1ac977e --- /dev/null +++ b/fairseq/fairseq_cli/score.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +BLEU scoring of generated translations against reference translations. +""" + +import argparse +import os +import sys + +from fairseq.data import dictionary +from fairseq.scoring import bleu + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Command-line script for BLEU scoring." 
+ ) + # fmt: off + parser.add_argument('-s', '--sys', default='-', help='system output') + parser.add_argument('-r', '--ref', required=True, help='references') + parser.add_argument('-o', '--order', default=4, metavar='N', + type=int, help='consider ngrams up to this order') + parser.add_argument('--ignore-case', action='store_true', + help='case-insensitive scoring') + parser.add_argument('--sacrebleu', action='store_true', + help='score with sacrebleu') + parser.add_argument('--sentence-bleu', action='store_true', + help='report sentence-level BLEUs (i.e., with +1 smoothing)') + # fmt: on + return parser + + +def cli_main(): + parser = get_parser() + args = parser.parse_args() + print(args) + + assert args.sys == "-" or os.path.exists( + args.sys + ), "System output file {} does not exist".format(args.sys) + assert os.path.exists(args.ref), "Reference file {} does not exist".format(args.ref) + + dict = dictionary.Dictionary() + + def readlines(fd): + for line in fd.readlines(): + if args.ignore_case: + yield line.lower() + else: + yield line + + if args.sacrebleu: + import sacrebleu + + def score(fdsys): + with open(args.ref) as fdref: + print(sacrebleu.corpus_bleu(fdsys, [fdref]).format()) + + elif args.sentence_bleu: + + def score(fdsys): + with open(args.ref) as fdref: + scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + for i, (sys_tok, ref_tok) in enumerate( + zip(readlines(fdsys), readlines(fdref)) + ): + scorer.reset(one_init=True) + sys_tok = dict.encode_line(sys_tok) + ref_tok = dict.encode_line(ref_tok) + scorer.add(ref_tok, sys_tok) + print(i, scorer.result_string(args.order)) + + else: + + def score(fdsys): + with open(args.ref) as fdref: + scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) + for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): + sys_tok = dict.encode_line(sys_tok) + ref_tok = dict.encode_line(ref_tok) + scorer.add(ref_tok, sys_tok) + print(scorer.result_string(args.order)) + + if args.sys == "-": + score(sys.stdin) + else: + with open(args.sys, "r") as f: + score(f) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/fairseq_cli/train.py b/fairseq/fairseq_cli/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f771bff654522358e41a710e486cf4c81ce87ea5 --- /dev/null +++ b/fairseq/fairseq_cli/train.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Train a new model on one or across multiple GPUs. +""" + +import argparse +import logging +import math +import os +import sys +from typing import Any, Callable, Dict, List, Optional, Tuple + +# We need to setup root logger before importing any fairseq libraries. 
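The comment above is load-bearing: `logging.basicConfig` only installs handlers when the root logger has none, and a library that logs through the module-level helpers (`logging.warning(...)` and friends) during import will implicitly configure the root logger, turning a later `basicConfig` call into a no-op. A small stdlib-only illustration of the ordering rule:

```python
import logging
import os
import sys

# Configure the root logger FIRST ...
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)

# ... and only then let library code create its loggers: records propagate to
# the root handler and come out formatted as configured above.
logging.getLogger("some.library").info("formatted by the root handler")
```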
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    stream=sys.stdout,
+)
+logger = logging.getLogger("fairseq_cli.train")
+
+import logging.config  # used by logging.config.dictConfig below; not implied by "import logging"
+
+import numpy as np
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils
+from fairseq.data import data_utils, iterators
+from fairseq.data.plasma_utils import PlasmaStore
+from fairseq.dataclass.configs import FairseqConfig
+from fairseq.dataclass.initialize import add_defaults
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap
+from fairseq.distributed import utils as distributed_utils
+from fairseq.file_io import PathManager
+from fairseq.logging import meters, metrics, progress_bar
+from fairseq.model_parallel.megatron_trainer import MegatronTrainer
+from fairseq.trainer import Trainer
+
+
+def main(cfg: FairseqConfig) -> None:
+    if isinstance(cfg, argparse.Namespace):
+        cfg = convert_namespace_to_omegaconf(cfg)
+
+    utils.import_user_module(cfg.common)
+    add_defaults(cfg)
+
+    if (
+        distributed_utils.is_master(cfg.distributed_training)
+        and "job_logging_cfg" in cfg
+    ):
+        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
+        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
+
+    assert (
+        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
+    ), "Must specify batch size either with --max-tokens or --batch-size"
+    metrics.reset()
+
+    if cfg.common.log_file is not None:
+        handler = logging.FileHandler(filename=cfg.common.log_file)
+        logger.addHandler(handler)
+
+    np.random.seed(cfg.common.seed)
+    utils.set_torch_seed(cfg.common.seed)
+
+    if distributed_utils.is_master(cfg.distributed_training):
+        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
+
+    # Print args
+    logger.info(cfg)
+
+    if cfg.checkpoint.write_checkpoints_asynchronously:
+        try:
+            import iopath  # noqa: F401
+        except ImportError:
+            logging.exception(
+                "Asynchronous checkpoint writing is specified but iopath is "
+                "not installed: `pip install iopath`"
+            )
+            return
+
+    # Setup task, e.g., translation, language modeling, etc.
+    task = tasks.setup_task(cfg.task)
+
+    assert cfg.criterion, "Please specify criterion to train a model"
+
+    # Build model and criterion
+    if cfg.distributed_training.ddp_backend == "fully_sharded":
+        with fsdp_enable_wrap(cfg.distributed_training):
+            model = fsdp_wrap(task.build_model(cfg.model))
+    else:
+        model = task.build_model(cfg.model)
+    criterion = task.build_criterion(cfg.criterion)
+    logger.info(model)
+    logger.info("task: {}".format(task.__class__.__name__))
+    logger.info("model: {}".format(model.__class__.__name__))
+    logger.info("criterion: {}".format(criterion.__class__.__name__))
+    logger.info(
+        "num. shared model params: {:,} (num. trained: {:,})".format(
+            sum(
+                p.numel() for p in model.parameters() if not getattr(p, "expert", False)
+            ),
+            sum(
+                p.numel()
+                for p in model.parameters()
+                if not getattr(p, "expert", False) and p.requires_grad
+            ),
+        )
+    )
+
+    logger.info(
+        "num. expert model params: {} (num.
trained: {})".format( + sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)), + sum( + p.numel() + for p in model.parameters() + if getattr(p, "expert", False) and p.requires_grad + ), + ) + ) + + # Load valid dataset (we load training data below, based on the latest checkpoint) + # We load the valid dataset AFTER building the model + if not cfg.dataset.disable_validation: + data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) + if cfg.dataset.combine_valid_subsets: + task.load_dataset("valid", combine=True, epoch=1) + else: + for valid_sub_split in cfg.dataset.valid_subset.split(","): + task.load_dataset(valid_sub_split, combine=False, epoch=1) + + # (optionally) Configure quantization + if cfg.common.quantization_config_path is not None: + quantizer = quantization_utils.Quantizer( + config_path=cfg.common.quantization_config_path, + max_epoch=cfg.optimization.max_epoch, + max_update=cfg.optimization.max_update, + ) + else: + quantizer = None + + # Build trainer + if cfg.common.model_parallel_size == 1: + trainer = Trainer(cfg, task, model, criterion, quantizer) + else: + trainer = MegatronTrainer(cfg, task, model, criterion) + logger.info( + "training on {} devices (GPUs/TPUs)".format( + cfg.distributed_training.distributed_world_size + ) + ) + logger.info( + "max tokens per device = {} and max sentences per device = {}".format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + ) + ) + + # Load the latest checkpoint if one is available and restore the + # corresponding train iterator + extra_state, epoch_itr = checkpoint_utils.load_checkpoint( + cfg.checkpoint, + trainer, + # don't cache epoch iterators for sharded datasets + disable_iterator_cache=task.has_sharded_data("train"), + ) + if cfg.common.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous("load_checkpoint") # wait for all workers + + max_epoch = cfg.optimization.max_epoch or math.inf + lr = trainer.get_lr() + + # TODO: a dry run on validation set to pin the memory + valid_subsets = cfg.dataset.valid_subset.split(",") + if not cfg.dataset.disable_validation: + for subset in valid_subsets: + logger.info('begin dry-run validation on "{}" subset'.format(subset)) + itr = trainer.get_valid_iterator(subset).next_epoch_itr( + shuffle=False, set_dataset_epoch=False # use a fixed valid set + ) + if cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + for _ in itr: + pass + # TODO: end of dry run section + + train_meter = meters.StopwatchMeter() + train_meter.start() + while epoch_itr.next_epoch_idx <= max_epoch: + if lr <= cfg.optimization.stop_min_lr: + logger.info( + f"stopping training because current learning rate ({lr}) is smaller " + "than or equal to minimum learning rate " + f"(--stop-min-lr={cfg.optimization.stop_min_lr})" + ) + break + + # train for one epoch + valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) + if should_stop: + break + + # only use first validation loss to update the learning rate + lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) + + epoch_itr = trainer.get_train_iterator( + epoch_itr.next_epoch_idx, + # sharded data: get train iterator for next epoch + load_dataset=task.has_sharded_data("train"), + # don't cache epoch iterators for sharded datasets + disable_iterator_cache=task.has_sharded_data("train"), + ) + train_meter.stop() + logger.info("done training in {:.1f} seconds".format(train_meter.sum)) + + # ioPath implementation to wait for all asynchronous file writes to complete. 
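Before the asynchronous-checkpoint flush that the `if` below performs, it helps to restate the loop just traversed: train one epoch at a time until `max_epoch` is reached, a stop condition fires, or the learning rate falls to `--stop-min-lr`, stepping the LR scheduler on the first validation loss after each epoch. A condensed, hypothetical sketch of that control flow (`train_one_epoch` stands in for the `train()` function defined later in this file):

```python
import math


def outer_loop(trainer, epoch_itr, max_epoch=math.inf, stop_min_lr=-1.0):
    # Hypothetical condensation of the epoch loop above, for orientation only.
    lr = trainer.get_lr()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= stop_min_lr:
            break  # the LR schedule has bottomed out
        valid_losses, should_stop = train_one_epoch(trainer, epoch_itr)
        if should_stop:
            break  # max_update, early-stopping patience, or stop_time_hours fired
        # only the first validation loss drives the LR scheduler
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
        epoch_itr = trainer.get_train_iterator(epoch_itr.next_epoch_idx)
```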
+ if cfg.checkpoint.write_checkpoints_asynchronously: + logger.info( + "ioPath PathManager waiting for all asynchronous checkpoint " + "writes to finish." + ) + PathManager.async_close() + logger.info("ioPath PathManager finished waiting.") + + +def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: + # skip check if no validation was done in the current epoch + if valid_loss is None: + return False + if cfg.checkpoint.patience <= 0: + return False + + def is_better(a, b): + return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b + + prev_best = getattr(should_stop_early, "best", None) + if prev_best is None or is_better(valid_loss, prev_best): + should_stop_early.best = valid_loss + should_stop_early.num_runs = 0 + return False + else: + should_stop_early.num_runs += 1 + if should_stop_early.num_runs >= cfg.checkpoint.patience: + logger.info( + "early stop since valid performance hasn't improved for last {} runs".format( + cfg.checkpoint.patience + ) + ) + return True + else: + return False + + +@metrics.aggregate("train") +def train( + cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr +) -> Tuple[List[Optional[float]], bool]: + """Train the model for one epoch and return validation losses.""" + # Initialize data iterator + itr = epoch_itr.next_epoch_itr( + fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus, + shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum), + ) + update_freq = ( + cfg.optimization.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(cfg.optimization.update_freq) + else cfg.optimization.update_freq[-1] + ) + itr = iterators.GroupedIterator( + itr, + update_freq, + skip_remainder_batch=cfg.optimization.skip_remainder_batch, + ) + if cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_file=cfg.common.log_file, + log_interval=cfg.common.log_interval, + epoch=epoch_itr.epoch, + aim_repo=( + cfg.common.aim_repo + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + aim_run_hash=( + cfg.common.aim_run_hash + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + aim_param_checkpoint_dir=cfg.checkpoint.save_dir, + tensorboard_logdir=( + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + ), + azureml_logging=( + cfg.common.azureml_logging + if distributed_utils.is_master(cfg.distributed_training) + else False + ), + ) + progress.update_config(_flatten_config(cfg)) + + trainer.begin_epoch(epoch_itr.epoch) + + valid_subsets = cfg.dataset.valid_subset.split(",") + should_stop = False + num_updates = trainer.get_num_updates() + logger.info("Start iterating over samples") + for i, samples in enumerate(progress): + with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( + "train_step-%d" % i + ): + log_output = trainer.train_step(samples) + + if log_output is not None: # not OOM, overflow, ... 
+ # log mid-epoch stats + num_updates = trainer.get_num_updates() + if num_updates % cfg.common.log_interval == 0: + stats = get_training_stats(metrics.get_smoothed_values("train_inner")) + progress.log(stats, tag="train_inner", step=num_updates) + + # reset mid-epoch stats after each log interval + # the end-of-epoch stats will still be preserved + metrics.reset_meters("train_inner") + + end_of_epoch = not itr.has_next() + valid_losses, should_stop = validate_and_save( + cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch + ) + + if should_stop: + break + + # log end-of-epoch stats + logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch)) + stats = get_training_stats(metrics.get_smoothed_values("train")) + progress.print(stats, tag="train", step=num_updates) + + # reset epoch-level meters + metrics.reset_meters("train") + return valid_losses, should_stop + + +def _flatten_config(cfg: DictConfig): + config = OmegaConf.to_container(cfg) + # remove any legacy Namespaces and replace with a single "args" + namespace = None + for k, v in list(config.items()): + if isinstance(v, argparse.Namespace): + namespace = v + del config[k] + if namespace is not None: + config["args"] = vars(namespace) + return config + + +def validate_and_save( + cfg: DictConfig, + trainer: Trainer, + task: tasks.FairseqTask, + epoch_itr, + valid_subsets: List[str], + end_of_epoch: bool, +) -> Tuple[List[Optional[float]], bool]: + num_updates = trainer.get_num_updates() + max_update = cfg.optimization.max_update or math.inf + + # Stopping conditions (and an additional one based on validation loss later + # on) + should_stop = False + if num_updates >= max_update: + should_stop = True + logger.info( + f"Stopping training due to " + f"num_updates: {num_updates} >= max_update: {max_update}" + ) + + training_time_hours = trainer.cumulative_training_time() / (60 * 60) + if ( + cfg.optimization.stop_time_hours > 0 + and training_time_hours > cfg.optimization.stop_time_hours + ): + should_stop = True + logger.info( + f"Stopping training due to " + f"cumulative_training_time: {training_time_hours} > " + f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)" + ) + + do_save = ( + (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) + or should_stop + or ( + cfg.checkpoint.save_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates >= cfg.dataset.validate_after_updates + ) + ) + do_validate = ( + ( + (not end_of_epoch and do_save) # validate during mid-epoch saves + or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) + or should_stop + or ( + cfg.dataset.validate_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 + ) + ) + and not cfg.dataset.disable_validation + and num_updates >= cfg.dataset.validate_after_updates + ) + + # Validate + valid_losses = [None] + if do_validate: + valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) + + should_stop |= should_stop_early(cfg, valid_losses[0]) + + # Save checkpoint + if do_save or should_stop: + cp_path = checkpoint_utils.save_checkpoint( + cfg.checkpoint, trainer, epoch_itr, valid_losses[0] + ) + if cp_path is not None and hasattr(task, "post_save"): + task.post_save(cp_path, num_updates) + + return valid_losses, should_stop + + +def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: + stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 
0) + return stats + + +def validate( + cfg: DictConfig, + trainer: Trainer, + task: tasks.FairseqTask, + epoch_itr, + subsets: List[str], +) -> List[Optional[float]]: + """Evaluate the model on the validation set(s) and return the losses.""" + + if cfg.dataset.fixed_validation_seed is not None: + # set fixed seed for every validation + utils.set_torch_seed(cfg.dataset.fixed_validation_seed) + + trainer.begin_valid_epoch(epoch_itr.epoch) + valid_losses = [] + for subset_idx, subset in enumerate(subsets): + logger.info('begin validation on "{}" subset'.format(subset)) + + # Initialize data iterator + itr = trainer.get_valid_iterator(subset).next_epoch_itr( + shuffle=False, set_dataset_epoch=False # use a fixed valid set + ) + if cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + epoch=epoch_itr.epoch, + prefix=f"valid on '{subset}' subset", + aim_repo=( + cfg.common.aim_repo + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + aim_run_hash=( + cfg.common.aim_run_hash + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + aim_param_checkpoint_dir=cfg.checkpoint.save_dir, + tensorboard_logdir=( + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + ), + ) + + # create a new root metrics aggregator so validation metrics + # don't pollute other aggregators (e.g., train meters) + with metrics.aggregate(new_root=True) as agg: + for i, sample in enumerate(progress): + if ( + cfg.dataset.max_valid_steps is not None + and i > cfg.dataset.max_valid_steps + ): + break + trainer.valid_step(sample) + + # log validation stats + # only tracking the best metric on the 1st validation subset + tracking_best = subset_idx == 0 + stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values(), tracking_best) + + if hasattr(task, "post_validate"): + task.post_validate(trainer.get_model(), stats, agg) + + progress.print(stats, tag=subset, step=trainer.get_num_updates()) + + valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) + return valid_losses + + +def get_valid_stats( + cfg: DictConfig, + trainer: Trainer, + stats: Dict[str, Any], + tracking_best: bool, +) -> Dict[str, Any]: + stats["num_updates"] = trainer.get_num_updates() + if tracking_best and hasattr(checkpoint_utils.save_checkpoint, "best"): + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min + stats[key] = best_function( + checkpoint_utils.save_checkpoint.best, + stats[cfg.checkpoint.best_checkpoint_metric], + ) + return stats + + +def cli_main( + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None +) -> None: + parser = options.get_training_parser() + args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + + cfg = convert_namespace_to_omegaconf(args) + + if cfg.common.use_plasma_view: + server = PlasmaStore(path=cfg.common.plasma_path) + logger.info( + f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}" + ) + + if args.profile: + with torch.cuda.profiler.profile(): + with 
torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, main) + else: + distributed_utils.call_main(cfg, main) + + # if cfg.common.use_plasma_view: + # server.server.kill() + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/fairseq_cli/validate.py b/fairseq/fairseq_cli/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..4617b6d542ae944bf7e795df26adc52d3954569c --- /dev/null +++ b/fairseq/fairseq_cli/validate.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys +from argparse import Namespace +from itertools import chain + +import torch +from omegaconf import DictConfig + +from fairseq import checkpoint_utils, distributed_utils, options, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import metrics, progress_bar +from fairseq.utils import reset_logging + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.validate") + + +def main(cfg: DictConfig, override_args=None): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + + reset_logging() + + assert ( + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" + + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) + + if cfg.distributed_training.distributed_world_size > 1: + data_parallel_world_size = distributed_utils.get_data_parallel_world_size() + data_parallel_rank = distributed_utils.get_data_parallel_rank() + else: + data_parallel_world_size = 1 + data_parallel_rank = 0 + + if override_args is not None: + overrides = vars(override_args) + overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) + else: + overrides = None + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=overrides, + suffix=cfg.checkpoint.checkpoint_suffix, + ) + model = models[0] + + # Move models to GPU + for model in models: + model.eval() + if use_fp16: + model.half() + if use_cuda: + model.cuda() + + # Print args + logger.info(saved_cfg) + + # Build criterion + criterion = task.build_criterion(saved_cfg.criterion) + criterion.eval() + + for subset in cfg.dataset.valid_subset.split(","): + try: + task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task) + dataset = task.dataset(subset) + except KeyError: + raise Exception("Cannot find dataset: " + subset) + + # Initialize data iterator + itr = task.get_batch_iterator( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[m.max_positions() for m in models], + ), + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + 
num_shards=data_parallel_world_size, + shard_id=data_parallel_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + prefix=f"valid on '{subset}' subset", + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + log_outputs = [] + for i, sample in enumerate(progress): + sample = utils.move_to_cuda(sample) if use_cuda else sample + _loss, _sample_size, log_output = task.valid_step(sample, model, criterion) + progress.log(log_output, step=i) + log_outputs.append(log_output) + + if data_parallel_world_size > 1: + log_outputs = distributed_utils.all_gather_list( + log_outputs, + max_size=cfg.common.all_gather_list_size, + group=distributed_utils.get_data_parallel_group(), + ) + log_outputs = list(chain.from_iterable(log_outputs)) + + with metrics.aggregate() as agg: + task.reduce_metrics(log_outputs, criterion) + log_output = agg.get_smoothed_values() + + progress.print(log_output, tag=subset, step=i) + + +def cli_main(): + parser = options.get_validation_parser() + args = options.parse_args_and_arch(parser) + + # only override args that are explicitly given on the command line + override_parser = options.get_validation_parser() + override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + + distributed_utils.call_main( + convert_namespace_to_omegaconf(args), main, override_args=override_args + ) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py b/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4884f5bdcbc3584a16a4c3daa0aec2e8ee8e7f90 --- /dev/null +++ b/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +__version__ = "0.1" diff --git a/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py b/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py new file mode 100644 index 0000000000000000000000000000000000000000..91926c4abc5d6c84f1817444bf677a971268927e --- /dev/null +++ b/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from dataclasses import dataclass, field + +from hydra.core.config_store import ConfigStore + +from hydra_plugins.hydra_submitit_launcher.config import SlurmQueueConf + + +@dataclass +class DependencySubmititConf(SlurmQueueConf): + """Slurm configuration overrides and specific parameters""" + + _target_: str = ( + "hydra_plugins.dependency_submitit_launcher.launcher.DependencySubmititLauncher" + ) + + +ConfigStore.instance().store( + group="hydra/launcher", + name="dependency_submitit_slurm", + node=DependencySubmititConf(), + provider="dependency_submitit_slurm", +) diff --git a/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py b/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py new file mode 100644 index 0000000000000000000000000000000000000000..b3fcf79e1703eaa71e6e0a8a46e9617fa5757f55 --- /dev/null +++ b/fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import logging +import os +import subprocess +from pathlib import Path +from typing import Any, List, Sequence + +from hydra.core.singleton import Singleton +from hydra.core.utils import JobReturn, filter_overrides +from omegaconf import OmegaConf + +log = logging.getLogger(__name__) + +from .config import DependencySubmititConf +from hydra_plugins.hydra_submitit_launcher.submitit_launcher import BaseSubmititLauncher + + +class DependencySubmititLauncher(BaseSubmititLauncher): + _EXECUTOR = "slurm" + + def launch( + self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int + ) -> Sequence[JobReturn]: + + # lazy import to ensure plugin discovery remains fast + import submitit + + assert self.config is not None + + num_jobs = len(job_overrides) + assert num_jobs > 0 + + next_script = None + + for jo in job_overrides: + if next_script is None: + for item in jo: + if "next_script=" in item: + next_script = item + break + assert ( + next_script is not None + ), "job overrides must contain +next_script=path/to/next/script" + jo.remove(next_script) + + idx = next_script.find("=") + next_script = next_script[idx + 1 :] + + params = self.params + # build executor + init_params = {"folder": self.params["submitit_folder"]} + specific_init_keys = {"max_num_timeout"} + + init_params.update( + **{ + f"{self._EXECUTOR}_{x}": y + for x, y in params.items() + if x in specific_init_keys + } + ) + init_keys = specific_init_keys | {"submitit_folder"} + executor = submitit.AutoExecutor(cluster=self._EXECUTOR, **init_params) + + # specify resources/parameters + baseparams = set(OmegaConf.structured(DependencySubmititConf).keys()) + params = { + x if x in baseparams else f"{self._EXECUTOR}_{x}": y + for x, y in params.items() + if x not in init_keys + } + executor.update_parameters(**params) + + log.info( + f"Submitit '{self._EXECUTOR}' sweep output dir : " + f"{self.config.hydra.sweep.dir}" + ) + sweep_dir = Path(str(self.config.hydra.sweep.dir)) + sweep_dir.mkdir(parents=True, exist_ok=True) + if "mode" in self.config.hydra.sweep: + mode = int(str(self.config.hydra.sweep.mode), 8) + os.chmod(sweep_dir, mode=mode) + + job_params: List[Any] = [] + for idx, overrides in enumerate(job_overrides): + idx = initial_job_idx + idx + lst = " ".join(filter_overrides(overrides)) + log.info(f"\t#{idx} : {lst}") + job_params.append( + ( + list(overrides), + "hydra.sweep.dir", + idx, + 
f"job_id_for_{idx}", + Singleton.get_state(), + ) + ) + + jobs = executor.map_array(self, *zip(*job_params)) + + for j, jp in zip(jobs, job_params): + job_id = str(j.job_id) + task_id = "0" if "_" not in job_id else job_id.split("_")[1] + sweep_config = self.config_loader.load_sweep_config(self.config, jp[0]) + dir = sweep_config.hydra.sweep.dir + + dir = ( + dir.replace("[", "") + .replace("]", "") + .replace("{", "") + .replace("}", "") + .replace(",", "_") + .replace("'", "") + .replace('"', "") + ) + + subprocess.call( + [next_script, job_id, task_id, dir], + shell=False, + ) + + return [j.results()[0] for j in jobs] diff --git a/fairseq/hydra_plugins/dependency_submitit_launcher/setup.py b/fairseq/hydra_plugins/dependency_submitit_launcher/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..bf795462bdf01851b7b3d90a864c312403d627ff --- /dev/null +++ b/fairseq/hydra_plugins/dependency_submitit_launcher/setup.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# type: ignore +from pathlib import Path + +from read_version import read_version +from setuptools import find_namespace_packages, setup + +setup( + name="dependency-submitit-launcher", + version=read_version("hydra_plugins/dependency_submitit_launcher", "__init__.py"), + author="Alexei Baevski", + author_email="abaevski@fb.com", + description="Dependency-supporting Submitit Launcher for Hydra apps", + packages=find_namespace_packages(include=["hydra_plugins.*"]), + classifiers=[ + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Operating System :: MacOS", + "Operating System :: POSIX :: Linux", + "Development Status :: 4 - Beta", + ], + install_requires=[ + "hydra-core>=1.0.4", + "submitit>=1.0.0", + ], + include_package_data=True, +) diff --git a/fairseq/scripts/__init__.py b/fairseq/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/scripts/average_checkpoints.py b/fairseq/scripts/average_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..49f4f9d912d3ae43271f0bc28aa608062c445c4d --- /dev/null +++ b/fairseq/scripts/average_checkpoints.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import collections +import os +import re + +import torch + +from fairseq.file_io import PathManager + + +def average_checkpoints(inputs): + """Loads checkpoints from inputs and returns a model with averaged weights. + + Args: + inputs: An iterable of string paths of checkpoints to load from. + + Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. 
+ """ + params_dict = collections.OrderedDict() + params_keys = None + new_state = None + num_models = len(inputs) + + for fpath in inputs: + with PathManager.open(fpath, "rb") as f: + state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), + ) + # Copies over the settings from the first checkpoint + if new_state is None: + new_state = state + + model_params = state["model"] + + model_params_keys = list(model_params.keys()) + if params_keys is None: + params_keys = model_params_keys + elif params_keys != model_params_keys: + raise KeyError( + "For checkpoint {}, expected list of params: {}, " + "but found: {}".format(f, params_keys, model_params_keys) + ) + + for k in params_keys: + p = model_params[k] + if isinstance(p, torch.HalfTensor): + p = p.float() + if k not in params_dict: + params_dict[k] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + params_dict[k] += p + + averaged_params = collections.OrderedDict() + for k, v in params_dict.items(): + averaged_params[k] = v + if averaged_params[k].is_floating_point(): + averaged_params[k].div_(num_models) + else: + averaged_params[k] //= num_models + new_state["model"] = averaged_params + return new_state + + +def last_n_checkpoints(paths, n, update_based, upper_bound=None): + assert len(paths) == 1 + path = paths[0] + if update_based: + pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt") + else: + pt_regexp = re.compile(r"checkpoint(\d+)\.pt") + files = PathManager.ls(path) + + entries = [] + for f in files: + m = pt_regexp.fullmatch(f) + if m is not None: + sort_key = int(m.group(1)) + if upper_bound is None or sort_key <= upper_bound: + entries.append((sort_key, m.group(0))) + if len(entries) < n: + raise Exception( + "Found {} checkpoint files but need at least {}", len(entries), n + ) + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] + + +def main(): + parser = argparse.ArgumentParser( + description="Tool to average the params of input checkpoints to " + "produce a new checkpoint", + ) + # fmt: off + parser.add_argument('--inputs', required=True, nargs='+', + help='Input checkpoint file paths.') + parser.add_argument('--output', required=True, metavar='FILE', + help='Write the new checkpoint containing the averaged weights to this path.') + num_group = parser.add_mutually_exclusive_group() + num_group.add_argument('--num-epoch-checkpoints', type=int, + help='if set, will try to find checkpoints with names checkpoint_xx.pt in the ' + 'path specified by input, and average last this many of them.') + num_group.add_argument('--num-update-checkpoints', type=int, + help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by' + ' input, and average last this many of them.') + num_group.add_argument('--num-best-checkpoints', type=int, default=0, + help='if set, will try to find checkpoints with names checkpoint_best_ee_xx.pt in the path specified by' + ' input, and average last this many of them.') + parser.add_argument('--checkpoint-upper-bound', type=int, + help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, ' + 'when using --num-update-checkpoints, this will set an upper bound on which update to use' + 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be' + ' averaged.' 
+
+
+def last_n_checkpoints(paths, n, update_based, upper_bound=None):
+    assert len(paths) == 1
+    path = paths[0]
+    if update_based:
+        pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
+    else:
+        pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
+    files = PathManager.ls(path)
+
+    entries = []
+    for f in files:
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            sort_key = int(m.group(1))
+            if upper_bound is None or sort_key <= upper_bound:
+                entries.append((sort_key, m.group(0)))
+    if len(entries) < n:
+        raise Exception(
+            "Found {} checkpoint files but need at least {}".format(len(entries), n)
+        )
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Tool to average the params of input checkpoints to "
+        "produce a new checkpoint",
+    )
+    # fmt: off
+    parser.add_argument('--inputs', required=True, nargs='+',
+                        help='Input checkpoint file paths.')
+    parser.add_argument('--output', required=True, metavar='FILE',
+                        help='Write the new checkpoint containing the averaged weights to this path.')
+    num_group = parser.add_mutually_exclusive_group()
+    num_group.add_argument('--num-epoch-checkpoints', type=int,
+                           help='if set, will try to find checkpoints with names checkpoint_xx.pt in the '
+                                'path specified by input, and average last this many of them.')
+    num_group.add_argument('--num-update-checkpoints', type=int,
+                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the '
+                                'path specified by input, and average last this many of them.')
+    num_group.add_argument('--num-best-checkpoints', type=int, default=0,
+                           help='if set, will try to find checkpoints with names checkpoint_best_ee_xx.pt in the '
+                                'path specified by input, and average last this many of them.')
+    parser.add_argument('--checkpoint-upper-bound', type=int,
+                        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
+                             'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
+                             'E.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would '
+                             'be averaged; with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints '
+                             '40500-50000 would be averaged, assuming --save-interval-updates 500.'
+                        )
+    # fmt: on
+    args = parser.parse_args()
+    print(args)
+
+    num = None
+    is_update_based = False
+    if args.num_update_checkpoints is not None:
+        num = args.num_update_checkpoints
+        is_update_based = True
+    elif args.num_epoch_checkpoints is not None:
+        num = args.num_epoch_checkpoints
+
+    assert args.checkpoint_upper_bound is None or (
+        args.num_epoch_checkpoints is not None
+        or args.num_update_checkpoints is not None
+    ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
+    assert (
+        args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
+    ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"
+
+    if num is not None:
+        args.inputs = last_n_checkpoints(
+            args.inputs,
+            num,
+            is_update_based,
+            upper_bound=args.checkpoint_upper_bound,
+        )
+        print("averaging checkpoints: ", args.inputs)
+
+    if args.num_best_checkpoints > 0:
+        args.inputs = list(
+            sorted(
+                args.inputs,
+                key=lambda x: float(
+                    os.path.basename(x).split("_")[-1].replace(".pt", "")
+                ),
+            )
+        )
+        args.inputs = args.inputs[: args.num_best_checkpoints]
+        for path in args.inputs:
+            print(os.path.basename(path))
+    new_state = average_checkpoints(args.inputs)
+    with PathManager.open(args.output, "wb") as f:
+        torch.save(new_state, f)
+    print("Finished writing averaged checkpoint to {}".format(args.output))
+
+
+if __name__ == "__main__":
+    main()
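`last_n_checkpoints` above selects files purely by name: epoch checkpoints look like `checkpoint3.pt`, update checkpoints like `checkpoint_3_5000.pt`. A quick sanity check of the two regexes (same patterns as in the script):

```python
import re

epoch_re = re.compile(r"checkpoint(\d+)\.pt")
update_re = re.compile(r"checkpoint_\d+_(\d+)\.pt")

assert epoch_re.fullmatch("checkpoint3.pt").group(1) == "3"
assert update_re.fullmatch("checkpoint_3_5000.pt").group(1) == "5000"
# an update checkpoint never matches the epoch pattern, so the two
# selection modes cannot pick up each other's files
assert epoch_re.fullmatch("checkpoint_3_5000.pt") is None
```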
+""" + +import argparse +import os +from itertools import zip_longest + + +def main(): + parser = argparse.ArgumentParser(description="symmetric alignment builer") + # fmt: off + parser.add_argument('--fast_align_dir', + help='path to fast_align build directory') + parser.add_argument('--mosesdecoder_dir', + help='path to mosesdecoder root directory') + parser.add_argument('--sym_heuristic', + help='heuristic to use for symmetrization', + default='grow-diag-final-and') + parser.add_argument('--source_file', + help='path to a file with sentences ' + 'in the source language') + parser.add_argument('--target_file', + help='path to a file with sentences ' + 'in the target language') + parser.add_argument('--output_dir', + help='output directory') + # fmt: on + args = parser.parse_args() + + fast_align_bin = os.path.join(args.fast_align_dir, "fast_align") + symal_bin = os.path.join(args.mosesdecoder_dir, "bin", "symal") + sym_fast_align_bin = os.path.join( + args.mosesdecoder_dir, "scripts", "ems", "support", "symmetrize-fast-align.perl" + ) + + # create joined file + joined_file = os.path.join(args.output_dir, "text.joined") + with open(args.source_file, "r", encoding="utf-8") as src, open( + args.target_file, "r", encoding="utf-8" + ) as tgt: + with open(joined_file, "w", encoding="utf-8") as joined: + for s, t in zip_longest(src, tgt): + print("{} ||| {}".format(s.strip(), t.strip()), file=joined) + + bwd_align_file = os.path.join(args.output_dir, "align.backward") + + # run forward alignment + fwd_align_file = os.path.join(args.output_dir, "align.forward") + fwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v > {FWD}".format( + FASTALIGN=fast_align_bin, JOINED=joined_file, FWD=fwd_align_file + ) + assert os.system(fwd_fast_align_cmd) == 0 + + # run backward alignment + bwd_align_file = os.path.join(args.output_dir, "align.backward") + bwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}".format( + FASTALIGN=fast_align_bin, JOINED=joined_file, BWD=bwd_align_file + ) + assert os.system(bwd_fast_align_cmd) == 0 + + # run symmetrization + sym_out_file = os.path.join(args.output_dir, "aligned") + sym_cmd = "{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}".format( + SYMFASTALIGN=sym_fast_align_bin, + FWD=fwd_align_file, + BWD=bwd_align_file, + SRC=args.source_file, + TGT=args.target_file, + OUT=sym_out_file, + HEURISTIC=args.sym_heuristic, + SYMAL=symal_bin, + ) + assert os.system(sym_cmd) == 0 + + +if __name__ == "__main__": + main() diff --git a/fairseq/scripts/check_installation.py b/fairseq/scripts/check_installation.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a9d9dd46fd457a70cd56f55e94449207ae505e --- /dev/null +++ b/fairseq/scripts/check_installation.py @@ -0,0 +1,36 @@ +from pathlib import Path +import os + +cwd = Path(".").resolve() +print("running 'check_installation.py' from:", cwd) + +# Old versions of numpy/torch can prevent loading the .so files +import torch + +print("torch:", torch.__version__) +import numpy + +print("numpy:", numpy.__version__) + +import fairseq + +print("Fairseq installed at:", fairseq.__file__) +import fairseq.criterions +import fairseq.dataclass.configs + +import _imp + +print("Should load following .so suffixes:", _imp.extension_suffixes()) + +so_files = list(Path(fairseq.__file__).parent.glob("*.so")) +so_files.extend(Path(fairseq.__file__).parent.glob("data/*.so")) +print("Found following .so files:") +for so_file in so_files: + print(f"- {so_file}") + +from fairseq import libbleu + +print("Found 
libbleu at", libbleu.__file__) +from fairseq.data import data_utils_fast + +print("Found data_utils_fast at", data_utils_fast.__file__) diff --git a/fairseq/scripts/compare_namespaces.py b/fairseq/scripts/compare_namespaces.py new file mode 100644 index 0000000000000000000000000000000000000000..bc24db624f8db36f546c263ba3a806dae6d466bf --- /dev/null +++ b/fairseq/scripts/compare_namespaces.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +"""Helper script to compare two argparse.Namespace objects.""" + +from argparse import Namespace # noqa + + +def main(): + + ns1 = eval(input("Namespace 1: ")) + ns2 = eval(input("Namespace 2: ")) + + def keys(ns): + ks = set() + for k in dir(ns): + if not k.startswith("_"): + ks.add(k) + return ks + + k1 = keys(ns1) + k2 = keys(ns2) + + def print_keys(ks, ns1, ns2=None): + for k in ks: + if ns2 is None: + print("{}\t{}".format(k, getattr(ns1, k, None))) + else: + print( + "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None)) + ) + + print("Keys unique to namespace 1:") + print_keys(k1 - k2, ns1) + print() + + print("Keys unique to namespace 2:") + print_keys(k2 - k1, ns2) + print() + + print("Overlapping keys with different values:") + ks = [k for k in k1 & k2 if getattr(ns1, k, "None") != getattr(ns2, k, "None")] + print_keys(ks, ns1, ns2) + print() + + +if __name__ == "__main__": + main() diff --git a/fairseq/scripts/compound_split_bleu.sh b/fairseq/scripts/compound_split_bleu.sh new file mode 100644 index 0000000000000000000000000000000000000000..1972fddcebff9a43a70bcf14c287175c68f60e3f --- /dev/null +++ b/fairseq/scripts/compound_split_bleu.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +if [ $# -ne 1 ]; then + echo "usage: $0 GENERATE_PY_OUTPUT" + exit 1 +fi + +GEN=$1 + +SYS=$GEN.sys +REF=$GEN.ref + +if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then + echo "not done generating" + exit +fi + +grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS +grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF +fairseq-score --sys $SYS --ref $REF diff --git a/fairseq/scripts/constraints/extract.py b/fairseq/scripts/constraints/extract.py new file mode 100644 index 0000000000000000000000000000000000000000..437b373856966e568ca93c13ebbd1417291e49da --- /dev/null +++ b/fairseq/scripts/constraints/extract.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Extracts random constraints from reference files.""" + +import argparse +import random +import sys + + +def get_phrase(words, index, length): + assert index < len(words) - length + 1 + phr = " ".join(words[index : index + length]) + for i in range(index, index + length): + words.pop(index) + return phr + + +def main(args): + + if args.seed: + random.seed(args.seed) + + for line in sys.stdin: + constraints = [] + + def add_constraint(constraint): + constraints.append(constraint) + + source = line.rstrip() + if "\t" in line: + source, target = line.split("\t") + if args.add_sos: + target = f" {target}" + if args.add_eos: + target = f"{target} " + + if len(target.split()) >= args.len: + words = [target] + + num = args.number + + choices = {} + for i in range(num): + if len(words) == 0: + break + segmentno = random.choice(range(len(words))) + segment = words.pop(segmentno) + tokens = segment.split() + phrase_index = random.choice(range(len(tokens))) + choice = " ".join( + tokens[phrase_index : min(len(tokens), phrase_index + args.len)] + ) + for j in range( + phrase_index, min(len(tokens), phrase_index + args.len) + ): + tokens.pop(phrase_index) + if phrase_index > 0: + words.append(" ".join(tokens[0:phrase_index])) + if phrase_index + 1 < len(tokens): + words.append(" ".join(tokens[phrase_index:])) + choices[target.find(choice)] = choice + + # mask out with spaces + target = target.replace(choice, " " * len(choice), 1) + + for key in sorted(choices.keys()): + add_constraint(choices[key]) + + print(source, *constraints, sep="\t") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases") + parser.add_argument("--len", "-l", type=int, default=1, help="phrase length") + parser.add_argument( + "--add-sos", default=False, action="store_true", help="add token" + ) + parser.add_argument( + "--add-eos", default=False, action="store_true", help="add token" + ) + parser.add_argument("--seed", "-s", default=0, type=int) + args = parser.parse_args() + + main(args) diff --git a/fairseq/scripts/constraints/validate.py b/fairseq/scripts/constraints/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..d531ad9f39b1df42c98fe8f26ad61fe53a9ac0c5 --- /dev/null +++ b/fairseq/scripts/constraints/validate.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + + +"""Reads in a fairseq output file, and verifies that the constraints +(C- lines) are present in the output (the first H- line). Assumes that +constraints are listed prior to the first hypothesis. +""" + +constraints = [] +found = 0 +total = 0 +for line in sys.stdin: + if line.startswith("C-"): + constraints.append(line.rstrip().split("\t")[1]) + elif line.startswith("H-"): + text = line.split("\t")[2] + + for constraint in constraints: + total += 1 + if constraint in text: + found += 1 + else: + print(f"No {constraint} in {text}", file=sys.stderr) + + constraints = [] + +print(f"Found {found} / {total} = {100 * found / total:.1f}%") diff --git a/fairseq/scripts/convert_dictionary.lua b/fairseq/scripts/convert_dictionary.lua new file mode 100644 index 0000000000000000000000000000000000000000..14ee8c997f642c8ff196617c2dcd0584037a60c4 --- /dev/null +++ b/fairseq/scripts/convert_dictionary.lua @@ -0,0 +1,34 @@ +-- Copyright (c) Facebook, Inc. 
diff --git a/fairseq/scripts/convert_dictionary.lua b/fairseq/scripts/convert_dictionary.lua
new file mode 100644
index 0000000000000000000000000000000000000000..14ee8c997f642c8ff196617c2dcd0584037a60c4
--- /dev/null
+++ b/fairseq/scripts/convert_dictionary.lua
@@ -0,0 +1,34 @@
+-- Copyright (c) Facebook, Inc. and its affiliates.
+--
+-- This source code is licensed under the MIT license found in the
+-- LICENSE file in the root directory of this source tree.
+--
+-- Usage: convert_dictionary.lua <dict.th7>
+require 'fairseq'
+require 'torch'
+require 'paths'
+
+if #arg < 1 then
+  print('usage: convert_dictionary.lua <dict.th7>')
+  os.exit(1)
+end
+if not paths.filep(arg[1]) then
+  print('error: file does not exist: ' .. arg[1])
+  os.exit(1)
+end
+
+dict = torch.load(arg[1])
+dst = paths.basename(arg[1]):gsub('.th7', '.txt')
+assert(dst:match('.txt$'))
+
+f = io.open(dst, 'w')
+for idx, symbol in ipairs(dict.index_to_symbol) do
+  if idx > dict.cutoff then
+    break
+  end
+  f:write(symbol)
+  f:write(' ')
+  f:write(dict.index_to_freq[idx])
+  f:write('\n')
+end
+f:close()
diff --git a/fairseq/scripts/convert_model.lua b/fairseq/scripts/convert_model.lua
new file mode 100644
index 0000000000000000000000000000000000000000..61b92139294fb90a25989ebd2ee52a765fb278a2
--- /dev/null
+++ b/fairseq/scripts/convert_model.lua
@@ -0,0 +1,108 @@
+-- Copyright (c) Facebook, Inc. and its affiliates.
+--
+-- This source code is licensed under the MIT license found in the
+-- LICENSE file in the root directory of this source tree.
+--
+-- Usage: convert_model.lua <model_epoch1.th7>
+require 'torch'
+local fairseq = require 'fairseq'
+
+model = torch.load(arg[1])
+
+function find_weight_norm(container, module)
+  for _, wn in ipairs(container:listModules()) do
+    if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
+      return wn
+    end
+  end
+end
+
+function push_state(dict, key, module)
+  if torch.type(module) == 'nn.Linear' then
+    local wn = find_weight_norm(model.module, module)
+    assert(wn)
+    dict[key .. '.weight_v'] = wn.v:float()
+    dict[key .. '.weight_g'] = wn.g:float()
+  elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
+    local wn = find_weight_norm(model.module, module)
+    assert(wn)
+    local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
+    dict[key .. '.weight_v'] = v
+    dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
+  else
+    dict[key .. '.weight'] = module.weight:float()
+  end
+  if module.bias then
+    dict[key .. '.bias'] = module.bias:float()
+  end
+end
+
+encoder_dict = {}
+decoder_dict = {}
+combined_dict = {}
+
+function encoder_state(encoder)
+  luts = encoder:findModules('nn.LookupTable')
+  push_state(encoder_dict, 'embed_tokens', luts[1])
+  push_state(encoder_dict, 'embed_positions', luts[2])
+
+  fcs = encoder:findModules('nn.Linear')
+  assert(#fcs >= 2)
+  local nInputPlane = fcs[1].weight:size(1)
+  push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
+  push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))
+
+  for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
+    push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
+    if nInputPlane ~= module.weight:size(3) / 2 then
+      push_state(encoder_dict, 'projections.' ..
tostring(i - 1), table.remove(fcs, 1)) + end + nInputPlane = module.weight:size(3) / 2 + end + assert(#fcs == 0) +end + +function decoder_state(decoder) + luts = decoder:findModules('nn.LookupTable') + push_state(decoder_dict, 'embed_tokens', luts[1]) + push_state(decoder_dict, 'embed_positions', luts[2]) + + fcs = decoder:findModules('nn.Linear') + local nInputPlane = fcs[1].weight:size(1) + push_state(decoder_dict, 'fc1', table.remove(fcs, 1)) + push_state(decoder_dict, 'fc2', fcs[#fcs - 1]) + push_state(decoder_dict, 'fc3', fcs[#fcs]) + + table.remove(fcs, #fcs) + table.remove(fcs, #fcs) + + for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do + if nInputPlane ~= module.weight:size(3) / 2 then + push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) + end + nInputPlane = module.weight:size(3) / 2 + + local prefix = 'attention.' .. tostring(i - 1) + push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1)) + push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1)) + push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module) + end + assert(#fcs == 0) +end + + +_encoder = model.module.modules[2] +_decoder = model.module.modules[3] + +encoder_state(_encoder) +decoder_state(_decoder) + +for k, v in pairs(encoder_dict) do + combined_dict['encoder.' .. k] = v +end +for k, v in pairs(decoder_dict) do + combined_dict['decoder.' .. k] = v +end + + +torch.save('state_dict.t7', combined_dict) diff --git a/fairseq/scripts/count_docs.py b/fairseq/scripts/count_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..58d85af85e91377a34dbd01f7674436152fd08e8 --- /dev/null +++ b/fairseq/scripts/count_docs.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Count the number of documents and average number of lines and tokens per +document in a large file. Documents should be separated by a single empty line. 
+""" + +import argparse +import gzip +import sys + +import numpy as np + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("--gzip", action="store_true") + args = parser.parse_args() + + def gopen(): + if args.gzip: + return gzip.open(args.input, "r") + else: + return open(args.input, "r", encoding="utf-8") + + num_lines = [] + num_toks = [] + with gopen() as h: + num_docs = 1 + num_lines_in_doc = 0 + num_toks_in_doc = 0 + for i, line in enumerate(h): + if len(line.strip()) == 0: # empty line indicates new document + num_docs += 1 + num_lines.append(num_lines_in_doc) + num_toks.append(num_toks_in_doc) + num_lines_in_doc = 0 + num_toks_in_doc = 0 + else: + num_lines_in_doc += 1 + num_toks_in_doc += len(line.rstrip().split()) + if i % 1000000 == 0: + print(i, file=sys.stderr, end="", flush=True) + elif i % 100000 == 0: + print(".", file=sys.stderr, end="", flush=True) + print(file=sys.stderr, flush=True) + + print("found {} docs".format(num_docs)) + print("average num lines per doc: {}".format(np.mean(num_lines))) + print("average num toks per doc: {}".format(np.mean(num_toks))) + + +if __name__ == "__main__": + main() diff --git a/fairseq/scripts/read_binarized.py b/fairseq/scripts/read_binarized.py new file mode 100644 index 0000000000000000000000000000000000000000..a414095d03fb022a6753e816fc8bfd80e11db24d --- /dev/null +++ b/fairseq/scripts/read_binarized.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +from fairseq.data import Dictionary, data_utils, indexed_dataset + + +def get_parser(): + parser = argparse.ArgumentParser( + description="writes text from binarized file to stdout" + ) + # fmt: off + parser.add_argument('--dataset-impl', help='dataset implementation', + choices=indexed_dataset.get_available_dataset_impl()) + parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None) + parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + dictionary = Dictionary.load(args.dict) if args.dict is not None else None + dataset = data_utils.load_indexed_dataset( + args.input, + dictionary, + dataset_impl=args.dataset_impl, + default="lazy", + ) + + for tensor_line in dataset: + if dictionary is None: + line = " ".join([str(int(x)) for x in tensor_line]) + else: + line = dictionary.string(tensor_line) + + print(line) + + +if __name__ == "__main__": + main() diff --git a/fairseq/scripts/rm_pt.py b/fairseq/scripts/rm_pt.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd063d21f0610fa7c42c2cfb2ee8af7c9c78677 --- /dev/null +++ b/fairseq/scripts/rm_pt.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import os +import re +import shutil +import sys + + +pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt") +pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt") +pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt") + + +def parse_checkpoints(files): + entries = [] + for f in files: + m = pt_regexp_epoch_based.fullmatch(f) + if m is not None: + entries.append((int(m.group(1)), m.group(0))) + else: + m = pt_regexp_update_based.fullmatch(f) + if m is not None: + entries.append((int(m.group(1)), m.group(0))) + return entries + + +def last_n_checkpoints(files, n): + entries = parse_checkpoints(files) + return [x[1] for x in sorted(entries, reverse=True)[:n]] + + +def every_n_checkpoints(files, n): + entries = parse_checkpoints(files) + return [x[1] for x in sorted(sorted(entries)[::-n])] + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Recursively delete checkpoint files from `root_dir`, " + "but preserve checkpoint_best.pt and checkpoint_last.pt" + ) + ) + parser.add_argument("root_dirs", nargs="*") + parser.add_argument( + "--save-last", type=int, default=0, help="number of last checkpoints to save" + ) + parser.add_argument( + "--save-every", type=int, default=0, help="interval of checkpoints to save" + ) + parser.add_argument( + "--preserve-test", + action="store_true", + help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)", + ) + parser.add_argument( + "--delete-best", action="store_true", help="delete checkpoint_best.pt" + ) + parser.add_argument( + "--delete-last", action="store_true", help="delete checkpoint_last.pt" + ) + parser.add_argument( + "--no-dereference", action="store_true", help="don't dereference symlinks" + ) + args = parser.parse_args() + + files_to_desymlink = [] + files_to_preserve = [] + files_to_delete = [] + for root_dir in args.root_dirs: + for root, _subdirs, files in os.walk(root_dir): + if args.save_last > 0: + to_save = last_n_checkpoints(files, args.save_last) + else: + to_save = [] + if args.save_every > 0: + to_save += every_n_checkpoints(files, args.save_every) + for file in files: + if not pt_regexp.fullmatch(file): + continue + full_path = os.path.join(root, file) + if ( + not os.path.basename(root).startswith("test_") or args.preserve_test + ) and ( + (file == "checkpoint_last.pt" and not args.delete_last) + or (file == "checkpoint_best.pt" and not args.delete_best) + or file in to_save + ): + if os.path.islink(full_path) and not args.no_dereference: + files_to_desymlink.append(full_path) + else: + files_to_preserve.append(full_path) + else: + files_to_delete.append(full_path) + + if len(files_to_desymlink) == 0 and len(files_to_delete) == 0: + print("Nothing to do.") + sys.exit(0) + + files_to_desymlink = sorted(files_to_desymlink) + files_to_preserve = sorted(files_to_preserve) + files_to_delete = sorted(files_to_delete) + + print("Operations to perform (in order):") + if len(files_to_desymlink) > 0: + for file in files_to_desymlink: + print(" - preserve (and dereference symlink): " + file) + if len(files_to_preserve) > 0: + for file in files_to_preserve: + print(" - preserve: " + file) + if len(files_to_delete) > 0: + for file in files_to_delete: + print(" - delete: " + file) + while True: + resp = input("Continue? 
(Y/N): ") + if resp.strip().lower() == "y": + break + elif resp.strip().lower() == "n": + sys.exit(0) + + print("Executing...") + if len(files_to_desymlink) > 0: + for file in files_to_desymlink: + realpath = os.path.realpath(file) + print("rm " + file) + os.remove(file) + print("cp {} {}".format(realpath, file)) + shutil.copyfile(realpath, file) + if len(files_to_delete) > 0: + for file in files_to_delete: + print("rm " + file) + os.remove(file) + + +if __name__ == "__main__": + main() diff --git a/fairseq/scripts/sacrebleu.sh b/fairseq/scripts/sacrebleu.sh new file mode 100644 index 0000000000000000000000000000000000000000..c10bf2b76ea032deabab6f5c9d8a3e1e884f1642 --- /dev/null +++ b/fairseq/scripts/sacrebleu.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +if [ $# -ne 4 ]; then + echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" + exit 1 +fi + +TESTSET=$1 +SRCLANG=$2 +TGTLANG=$3 + +GEN=$4 + +if ! command -v sacremoses &> /dev/null +then + echo "sacremoses could not be found, please install with: pip install sacremoses" + exit +fi + +grep ^H $GEN \ +| sed 's/^H\-//' \ +| sort -n -k 1 \ +| cut -f 3 \ +| sacremoses detokenize \ +> $GEN.sorted.detok + +sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok diff --git a/fairseq/scripts/shard_docs.py b/fairseq/scripts/shard_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..97232c3c845ee01dc5ab627388934cc0f9588280 --- /dev/null +++ b/fairseq/scripts/shard_docs.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Split a large file into shards while respecting document boundaries. Documents +should be separated by a single empty line. +""" + +import argparse +import contextlib + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("--num-shards", type=int) + args = parser.parse_args() + + assert args.num_shards is not None and args.num_shards > 1 + + with open(args.input, "r", encoding="utf-8") as h: + with contextlib.ExitStack() as stack: + outputs = [ + stack.enter_context( + open(args.input + ".shard" + str(i), "w", encoding="utf-8") + ) + for i in range(args.num_shards) + ] + + doc = [] + first_doc = [True] * args.num_shards + + def output_doc(i): + if not first_doc[i]: + outputs[i].write("\n") + first_doc[i] = False + for line in doc: + outputs[i].write(line) + doc.clear() + + num_docs = 0 + for line in h: + if line.strip() == "": # empty line indicates new document + output_doc(num_docs % args.num_shards) + num_docs += 1 + else: + doc.append(line) + output_doc(num_docs % args.num_shards) + + +if __name__ == "__main__": + main() diff --git a/fairseq/scripts/split_train_valid_docs.py b/fairseq/scripts/split_train_valid_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..ff159785284a13b44626b207d84430c592acaf8f --- /dev/null +++ b/fairseq/scripts/split_train_valid_docs.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Split a large file into a train and valid set while respecting document +boundaries. Documents should be separated by a single empty line. 
+""" + +import argparse +import random +import sys + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("sample_output", help="train output file") + parser.add_argument("remainder_output", help="valid output file") + parser.add_argument("-k", type=int, help="remainder size") + parser.add_argument( + "--lines", action="store_true", help="split lines instead of docs" + ) + args = parser.parse_args() + + assert args.k is not None + + sample = [] + remainder = [] + num_docs = [0] + + def update_sample(doc): + if len(sample) < args.k: + sample.append(doc.copy()) + else: + i = num_docs[0] + j = random.randrange(i + 1) + if j < args.k: + remainder.append(sample[j]) + sample[j] = doc.copy() + else: + remainder.append(doc.copy()) + num_docs[0] += 1 + doc.clear() + + with open(args.input, "r", encoding="utf-8") as h: + doc = [] + for i, line in enumerate(h): + if line.strip() == "": # empty line indicates new document + update_sample(doc) + else: + doc.append(line) + if args.lines: + update_sample(doc) + if i % 1000000 == 0: + print(i, file=sys.stderr, end="", flush=True) + elif i % 100000 == 0: + print(".", file=sys.stderr, end="", flush=True) + if len(doc) > 0: + update_sample(doc) + print(file=sys.stderr, flush=True) + + assert len(sample) == args.k + + with open(args.sample_output, "w", encoding="utf-8") as out: + first = True + for doc in sample: + if not first and not args.lines: + out.write("\n") + first = False + for line in doc: + out.write(line) + + with open(args.remainder_output, "w", encoding="utf-8") as out: + first = True + for doc in remainder: + if not first and not args.lines: + out.write("\n") + first = False + for line in doc: + out.write(line) + + +if __name__ == "__main__": + main() diff --git a/fairseq/scripts/spm_decode.py b/fairseq/scripts/spm_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..7d7b68b240265924601ca6a738ed3d7b4b8e9cda --- /dev/null +++ b/fairseq/scripts/spm_decode.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import argparse
+
+import sentencepiece as spm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model", required=True, help="sentencepiece model to use for decoding"
+    )
+    parser.add_argument("--input", required=True, help="input file to decode")
+    parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
+    args = parser.parse_args()
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.model)
+
+    if args.input_format == "piece":
+
+        def decode(input):
+            return "".join(sp.DecodePieces(input))
+
+    elif args.input_format == "id":
+
+        def decode(input):
+            return "".join(sp.DecodeIds(input))
+
+    else:
+        raise NotImplementedError
+
+    def tok2int(tok):
+        # remap reference-side <unk> (represented as <<unk>>) to 0
+        return int(tok) if tok != "<<unk>>" else 0
+
+    with open(args.input, "r", encoding="utf-8") as h:
+        for line in h:
+            if args.input_format == "id":
+                print(decode(list(map(tok2int, line.rstrip().split()))))
+            elif args.input_format == "piece":
+                print(decode(line.rstrip().split()))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fairseq/scripts/spm_encode.py b/fairseq/scripts/spm_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91e0bb728a33448c1415aee6036ac9d0feac11f
--- /dev/null
+++ b/fairseq/scripts/spm_encode.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model", required=True, help="sentencepiece model to use for encoding"
+    )
+    parser.add_argument(
+        "--inputs", nargs="+", default=["-"], help="input files to filter/encode"
+    )
+    parser.add_argument(
+        "--outputs", nargs="+", default=["-"], help="path to save encoded outputs"
+    )
+    parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+    parser.add_argument(
+        "--min-len",
+        type=int,
+        metavar="N",
+        help="filter sentence pairs with fewer than N tokens",
+    )
+    parser.add_argument(
+        "--max-len",
+        type=int,
+        metavar="N",
+        help="filter sentence pairs with more than N tokens",
+    )
+    args = parser.parse_args()
+
+    assert len(args.inputs) == len(
+        args.outputs
+    ), "number of input and output paths should match"
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.model)
+
+    if args.output_format == "piece":
+
+        def encode(input):
+            return sp.EncodeAsPieces(input)
+
+    elif args.output_format == "id":
+
+        def encode(input):
+            return list(map(str, sp.EncodeAsIds(input)))
+
+    else:
+        raise NotImplementedError
+
+    if args.min_len is not None or args.max_len is not None:
+
+        def valid(line):
+            return (args.min_len is None or len(line) >= args.min_len) and (
+                args.max_len is None or len(line) <= args.max_len
+            )
+
+    else:
+
+        def valid(line):
+            return True
"num_filtered": 0, + } + + def encode_line(line): + line = line.strip() + if len(line) > 0: + line = encode(line) + if valid(line): + return line + else: + stats["num_filtered"] += 1 + else: + stats["num_empty"] += 1 + return None + + for i, lines in enumerate(zip(*inputs), start=1): + enc_lines = list(map(encode_line, lines)) + if not any(enc_line is None for enc_line in enc_lines): + for enc_line, output_h in zip(enc_lines, outputs): + print(" ".join(enc_line), file=output_h) + if i % 10000 == 0: + print("processed {} lines".format(i), file=sys.stderr) + + print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr) + print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/fairseq/scripts/spm_train.py b/fairseq/scripts/spm_train.py new file mode 100644 index 0000000000000000000000000000000000000000..9db668fd4166a860198784990de68ea26157995d --- /dev/null +++ b/fairseq/scripts/spm_train.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import sys + +import sentencepiece as spm + + +if __name__ == "__main__": + spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:])) diff --git a/fairseq/scripts/test_fsdp.sh b/fairseq/scripts/test_fsdp.sh new file mode 100644 index 0000000000000000000000000000000000000000..1f428a035e4474427ded991f8e8307ea59f61f69 --- /dev/null +++ b/fairseq/scripts/test_fsdp.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +rm -rf fsdp_dummy +mkdir -p fsdp_dummy +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 256 --batch-size 8 \ + --arch transformer_lm_gpt2_tiny \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 5 --log-format json --log-interval 1 \ + --save-interval-updates 5 --save-dir fsdp_dummy --disable-validation \ + --restore-file x.pt "$@" + +# Now we try to load the checkpoint +CUDA_VISIBLE_DEVICES=0,1 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 256 --batch-size 8 \ + --arch transformer_lm_gpt2_tiny \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 2 --log-format json --log-interval 1 \ + --save-interval-updates 2 --save-dir fsdp_dummy diff --git a/fairseq/tests/__init__.py b/fairseq/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/tests/tasks/test_masked_lm.py b/fairseq/tests/tasks/test_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..215cd355b0778393ce1bbf64645cf34692acc86b --- /dev/null +++ b/fairseq/tests/tasks/test_masked_lm.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
diff --git a/fairseq/tests/tasks/test_masked_lm.py b/fairseq/tests/tasks/test_masked_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..215cd355b0778393ce1bbf64645cf34692acc86b
--- /dev/null
+++ b/fairseq/tests/tasks/test_masked_lm.py
@@ -0,0 +1,78 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+from tempfile import TemporaryDirectory
+
+from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
+from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
+from tests.utils import build_vocab, make_data
+
+
+class TestMaskedLM(unittest.TestCase):
+    def test_masks_tokens(self):
+        with TemporaryDirectory() as dirname:
+
+            # prep input file
+            raw_file = os.path.join(dirname, "raw")
+            data = make_data(out_file=raw_file)
+            vocab = build_vocab(data)
+
+            # binarize
+            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
+            split = "train"
+            bin_file = os.path.join(dirname, split)
+            FileBinarizer.multiprocess_dataset(
+                input_file=raw_file,
+                binarizer=binarizer,
+                dataset_impl="mmap",
+                vocab_size=len(vocab),
+                output_prefix=bin_file,
+            )
+
+            # setup task
+            cfg = MaskedLMConfig(
+                data=dirname,
+                seed=42,
+                mask_prob=0.5,  # increasing the odds of masking
+                random_token_prob=0,  # avoiding random tokens for exact match
+                leave_unmasked_prob=0,  # always masking for exact match
+            )
+            task = MaskedLMTask(cfg, binarizer.dict)
+
+            original_dataset = task._load_dataset_split(bin_file, 1, False)
+
+            # load datasets
+            task.load_dataset(split)
+            masked_dataset = task.dataset(split)
+
+            mask_index = task.source_dictionary.index("<mask>")
+            iterator = task.get_batch_iterator(
+                dataset=masked_dataset,
+                max_tokens=65_536,
+                max_positions=4_096,
+            ).next_epoch_itr(shuffle=False)
+            for batch in iterator:
+                for sample in range(len(batch)):
+                    net_input = batch["net_input"]
+                    masked_src_tokens = net_input["src_tokens"][sample]
+                    masked_src_length = net_input["src_lengths"][sample]
+                    masked_tgt_tokens = batch["target"][sample]
+
+                    sample_id = batch["id"][sample]
+                    original_tokens = original_dataset[sample_id]
+                    original_tokens = original_tokens.masked_select(
+                        masked_src_tokens[:masked_src_length] == mask_index
+                    )
+                    masked_tokens = masked_tgt_tokens.masked_select(
+                        masked_tgt_tokens != task.source_dictionary.pad()
+                    )
+
+                    assert masked_tokens.equal(original_tokens)
+
+
+if __name__ == "__main__":
+    unittest.main()
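The assertion at the core of the test above compares two `masked_select` views: positions where the corrupted source holds `<mask>` must, in the target, carry the original tokens, and everything else in the target is padding. The check in isolation, with toy tensors (made-up indices):

```python
import torch

mask_idx, pad_idx = 4, 1
original = torch.tensor([10, 11, 12, 13])
masked_src = torch.tensor([10, mask_idx, 12, mask_idx])  # corrupted input
masked_tgt = torch.tensor([pad_idx, 11, pad_idx, 13])    # labels only at masks

recovered = original.masked_select(masked_src == mask_idx)
labels = masked_tgt.masked_select(masked_tgt != pad_idx)
assert recovered.equal(labels)  # both are tensor([11, 13])
```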
diff --git a/fairseq/tests/tasks/test_span_masked_lm.py b/fairseq/tests/tasks/test_span_masked_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d289cf843e2dc805403a79c31dfeb256d4ab9115
--- /dev/null
+++ b/fairseq/tests/tasks/test_span_masked_lm.py
@@ -0,0 +1,106 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+from tempfile import TemporaryDirectory
+
+from fairseq import options
+from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.tasks.span_masked_lm import SpanMaskedLMTask
+from tests.utils import build_vocab, make_data
+
+
+class TestSpanMaskedLM(unittest.TestCase):
+    def test_masks_token_spans(self):
+        with TemporaryDirectory() as dirname:
+
+            # prep input file
+            raw_file = os.path.join(dirname, "raw")
+            data = make_data(out_file=raw_file)
+            vocab = build_vocab(data)
+
+            # binarize
+            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
+            split = "train"
+            bin_file = os.path.join(dirname, split)
+            dataset_impl = "mmap"
+
+            FileBinarizer.multiprocess_dataset(
+                input_file=raw_file,
+                binarizer=binarizer,
+                dataset_impl=dataset_impl,
+                vocab_size=len(vocab),
+                output_prefix=bin_file,
+            )
+
+            # adding sentinel tokens
+            for i in range(100):
+                vocab.add_symbol(f"<extra_id_{i}>")
+
+            # setup task
+            train_args = options.parse_args_and_arch(
+                options.get_training_parser(),
+                [
+                    "--task",
+                    "span_masked_lm",
+                    "--arch",
+                    "bart_base",
+                    "--seed",
+                    "42",
+                    dirname,
+                ],
+            )
+            cfg = convert_namespace_to_omegaconf(train_args)
+            task = SpanMaskedLMTask(cfg.task, binarizer.dict)
+
+            # load datasets
+            original_dataset = task._load_dataset_split(bin_file, 1, False)
+            task.load_dataset(split)
+            masked_dataset = task.dataset(split)
+
+            iterator = task.get_batch_iterator(
+                dataset=masked_dataset,
+                max_tokens=65_536,
+                max_positions=4_096,
+            ).next_epoch_itr(shuffle=False)
+            num_tokens = len(vocab)
+            for batch in iterator:
+                for sample in range(len(batch)):
+                    sample_id = batch["id"][sample]
+                    original_tokens = original_dataset[sample_id]
+                    masked_src_tokens = batch["net_input"]["src_tokens"][sample]
+                    masked_src_length = batch["net_input"]["src_lengths"][sample]
+                    masked_tgt_tokens = batch["target"][sample]
+
+                    original_offset = 0
+                    masked_tgt_offset = 0
+                    extra_id_token = len(vocab) - 1
+                    for masked_src_token in masked_src_tokens[:masked_src_length]:
+                        if masked_src_token == extra_id_token:
+                            assert (
+                                masked_src_token == masked_tgt_tokens[masked_tgt_offset]
+                            )
+                            extra_id_token -= 1
+                            masked_tgt_offset += 1
+                            while (
+                                original_offset < len(original_tokens)
+                                and masked_tgt_tokens[masked_tgt_offset]
+                                != extra_id_token
+                            ):
+                                assert (
+                                    original_tokens[original_offset]
+                                    == masked_tgt_tokens[masked_tgt_offset]
+                                )
+                                original_offset += 1
+                                masked_tgt_offset += 1
+                        else:
+                            assert original_tokens[original_offset] == masked_src_token
+                            original_offset += 1
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fairseq/tests/test_activation_checkpointing.py b/fairseq/tests/test_activation_checkpointing.py
new file mode 100644
index 0000000000000000000000000000000000000000..647a9572886f8aff09a4aadc0b21e1d5817ff38e
--- /dev/null
+++ b/fairseq/tests/test_activation_checkpointing.py
@@ -0,0 +1,79 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
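For orientation, the span-masked LM test above walks a T5-style sentinel layout: each masked span in the source collapses to a single `<extra_id_i>` token (the test counts sentinels down from `len(vocab) - 1`), and the target interleaves the sentinels with the original span contents. A toy reconstruction over token strings rather than vocabulary indices, illustrative only:

```python
source = ["the", "<extra_id_0>", "fox", "<extra_id_1>"]
target = ["<extra_id_0>", "quick", "brown", "<extra_id_1>", "jumps"]

def reconstruct(source, target):
    # Splice the spans stored in the target back into the source.
    out, t = [], 0
    for tok in source:
        if tok.startswith("<extra_id_"):
            assert target[t] == tok  # sentinel order must line up
            t += 1
            while t < len(target) and not target[t].startswith("<extra_id_"):
                out.append(target[t])
                t += 1
        else:
            out.append(tok)
    return out

assert reconstruct(source, target) == ["the", "quick", "brown", "fox", "jumps"]
```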
+ +import unittest + +import torch +import torch.nn as nn +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from torch.utils.checkpoint import checkpoint + + +class Model(nn.Module): + def __init__( + self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs + ): + super().__init__() + torch.manual_seed(0) + self.use_pytorch_checkpoint = use_pytorch_checkpoint + self.ffn = nn.Sequential( + nn.Linear(32, 128), + # add a Dropout layer to test RNG save/restore + nn.Dropout(p=0.5), + nn.Linear(128, 32), + ) + if use_fairseq_checkpoint: + self.ffn = checkpoint_wrapper(self.ffn, **kwargs) + self.out = nn.Linear(32, 1) + + def forward(self, x): + if self.use_pytorch_checkpoint: + x = checkpoint(self.ffn, x) + else: + x = self.ffn(x) + return self.out(x) + + +class TestActivationCheckpointing(unittest.TestCase): + def _test_checkpoint_wrapper(self, device, log_memory_usage=False): + def get_loss_and_gnorm(model): + torch.manual_seed(1) + input = torch.rand(2, 16, 32).requires_grad_(True).to(device) + model.zero_grad() + loss = model(input).sum() + loss.backward() + gnorm = torch.norm( + torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]) + ) + return {"loss": loss, "gnorm": gnorm} + + model = Model().to(device) + no_cpt = get_loss_and_gnorm(model) + + model = Model(use_pytorch_checkpoint=True).to(device) + pyt_cpt = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"]) + + model = Model(use_fairseq_checkpoint=True).to(device) + fairseq_cpt = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"]) + + model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device) + fairseq_cpt_offload = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"]) + + def test_checkpoint_wrapper_cpu(self): + self._test_checkpoint_wrapper(device=torch.device("cpu")) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_checkpoint_wrapper_cuda(self): + self._test_checkpoint_wrapper(device=torch.device("cuda")) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_amp_optimizer.py b/fairseq/tests/test_amp_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6073a926513dc07bd8421b766ee9eb8cc94333 --- /dev/null +++ b/fairseq/tests/test_amp_optimizer.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
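+
+# Editorial sketch, not part of the original patch: the canonical GradScaler
+# recipe that the test below verifies step by step -- forward under autocast,
+# scale the loss before backward, unscale before gradient clipping, then step
+# and update the scale. `_demo_amp_step` is a hypothetical helper showing the
+# loop body with plain PyTorch.
+def _demo_amp_step(model, optimizer, loss_fn, inputs, targets, scaler):
+    import torch
+    from torch.cuda.amp import autocast
+
+    optimizer.zero_grad()
+    with autocast():
+        loss = loss_fn(model(inputs), targets)
+    scaler.scale(loss).backward()
+    scaler.unscale_(optimizer)  # so clipping sees gradients at their true scale
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+    scaler.step(optimizer)  # the step is skipped if gradients contain inf/nan
+    scaler.update()
+    return loss.item()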
+ +import argparse +import copy +import unittest + +import torch +from torch.cuda.amp import GradScaler, autocast + +from fairseq.optim import build_optimizer + + +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") +class TestGradientScalingAMP(unittest.TestCase): + def setUp(self): + self.x = torch.tensor([2.0]).cuda().half() + weight = 3.0 + bias = 5.0 + self.error = 1.0 + self.target = torch.tensor([self.x * weight + bias + self.error]).cuda() + self.loss_fn = torch.nn.L1Loss() + + self.model = torch.nn.Linear(1, 1) + self.model.weight.data = torch.tensor([[weight]]) + self.model.bias.data = torch.tensor([bias]) + self.model.cuda() + self.params = list(self.model.parameters()) + + self.namespace_dls = argparse.Namespace( + optimizer="adam", + lr=[0.1], + adam_betas="(0.9, 0.999)", + adam_eps=1e-8, + weight_decay=0.0, + threshold_loss_scale=1, + min_loss_scale=1e-4, + ) + self.scaler = GradScaler( + init_scale=1, + growth_interval=1, + ) + + def run_iter(self, model, params, optimizer): + optimizer.zero_grad() + with autocast(): + y = model(self.x) + loss = self.loss_fn(y, self.target) + self.scaler.scale(loss).backward() + self.assertEqual(loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16)) + + self.scaler.unscale_(optimizer) + grad_norm = optimizer.clip_grad_norm(0) + self.assertAlmostEqual(grad_norm.item(), 2.2361, 4) + + self.scaler.step(optimizer) + self.scaler.update() + self.assertEqual( + model.weight, + torch.tensor([[3.1]], device="cuda:0", requires_grad=True), + ) + self.assertEqual( + model.bias, + torch.tensor([5.1], device="cuda:0", requires_grad=True), + ) + self.assertEqual(self.scaler.get_scale(), 2.0) + + def test_automatic_mixed_precision(self): + model = copy.deepcopy(self.model) + params = list(model.parameters()) + optimizer = build_optimizer(self.namespace_dls, params) + + self.run_iter(model, params, optimizer) diff --git a/fairseq/tests/test_average_checkpoints.py b/fairseq/tests/test_average_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..f348b56b869372d8434fe03f13324d78e9093fa2 --- /dev/null +++ b/fairseq/tests/test_average_checkpoints.py @@ -0,0 +1,134 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
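+
+# Editorial sketch, not part of the original patch: checkpoint averaging is an
+# elementwise mean over the saved state dicts; integer tensors are averaged in
+# integer arithmetic, hence the truncation the test below expects for key "c".
+# `_demo_average_state_dicts` is a hypothetical reference implementation.
+def _demo_average_state_dicts(state_dicts):
+    import torch
+
+    avg = {}
+    for key in state_dicts[0]:
+        stacked = torch.stack([sd[key] for sd in state_dicts])
+        if stacked.is_floating_point():
+            avg[key] = stacked.mean(dim=0)
+        else:
+            # integer division truncates, matching params_avg below
+            avg[key] = stacked.sum(dim=0) // len(state_dicts)
+    return avg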
+
+import collections
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from scripts.average_checkpoints import average_checkpoints
+from torch import nn
+
+
+class ModelWithSharedParameter(nn.Module):
+    def __init__(self):
+        super(ModelWithSharedParameter, self).__init__()
+        self.embedding = nn.Embedding(1000, 200)
+        self.FC1 = nn.Linear(200, 200)
+        self.FC2 = nn.Linear(200, 200)
+        # tie weight in FC2 to FC1
+        self.FC2.weight = nn.Parameter(self.FC1.weight)
+        self.FC2.bias = nn.Parameter(self.FC1.bias)
+
+        self.relu = nn.ReLU()
+
+    def forward(self, input):
+        return self.FC2(self.relu(self.FC1(input))) + self.FC1(input)
+
+
+class TestAverageCheckpoints(unittest.TestCase):
+    def test_average_checkpoints(self):
+        params_0 = collections.OrderedDict(
+            [
+                ("a", torch.DoubleTensor([100.0])),
+                ("b", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
+                ("c", torch.IntTensor([7, 8, 9])),
+            ]
+        )
+        params_1 = collections.OrderedDict(
+            [
+                ("a", torch.DoubleTensor([1.0])),
+                ("b", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
+                ("c", torch.IntTensor([2, 2, 2])),
+            ]
+        )
+        params_avg = collections.OrderedDict(
+            [
+                ("a", torch.DoubleTensor([50.5])),
+                ("b", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
+                # We expect truncation for integer division
+                ("c", torch.IntTensor([4, 5, 5])),
+            ]
+        )
+
+        fd_0, path_0 = tempfile.mkstemp()
+        fd_1, path_1 = tempfile.mkstemp()
+        torch.save(collections.OrderedDict([("model", params_0)]), path_0)
+        torch.save(collections.OrderedDict([("model", params_1)]), path_1)
+
+        output = average_checkpoints([path_0, path_1])["model"]
+
+        os.close(fd_0)
+        os.remove(path_0)
+        os.close(fd_1)
+        os.remove(path_1)
+
+        for (k_expected, v_expected), (k_out, v_out) in zip(
+            params_avg.items(), output.items()
+        ):
+            self.assertEqual(
+                k_expected,
+                k_out,
+                "Key mismatch - expected {} but found {}. 
" + "(Expected list of keys: {} vs actual list of keys: {})".format( + k_expected, k_out, params_avg.keys(), output.keys() + ), + ) + np.testing.assert_allclose( + v_expected.numpy(), + v_out.numpy(), + err_msg="Tensor value mismatch for key {}".format(k_expected), + ) + + def test_average_checkpoints_with_shared_parameters(self): + def _construct_model_with_shared_parameters(path, value): + m = ModelWithSharedParameter() + nn.init.constant_(m.FC1.weight, value) + torch.save({"model": m.state_dict()}, path) + return m + + tmpdir = tempfile.mkdtemp() + paths = [] + path = os.path.join(tmpdir, "m1.pt") + m1 = _construct_model_with_shared_parameters(path, 1.0) + paths.append(path) + + path = os.path.join(tmpdir, "m2.pt") + m2 = _construct_model_with_shared_parameters(path, 2.0) + paths.append(path) + + path = os.path.join(tmpdir, "m3.pt") + m3 = _construct_model_with_shared_parameters(path, 3.0) + paths.append(path) + + new_model = average_checkpoints(paths) + self.assertTrue( + torch.equal( + new_model["model"]["embedding.weight"], + (m1.embedding.weight + m2.embedding.weight + m3.embedding.weight) / 3.0, + ) + ) + + self.assertTrue( + torch.equal( + new_model["model"]["FC1.weight"], + (m1.FC1.weight + m2.FC1.weight + m3.FC1.weight) / 3.0, + ) + ) + + self.assertTrue( + torch.equal( + new_model["model"]["FC2.weight"], + (m1.FC2.weight + m2.FC2.weight + m3.FC2.weight) / 3.0, + ) + ) + shutil.rmtree(tmpdir) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_backtranslation_dataset.py b/fairseq/tests/test_backtranslation_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dffc3b49387dfdc046ea23d7db179377040b7cbc --- /dev/null +++ b/fairseq/tests/test_backtranslation_dataset.py @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import tests.utils as test_utils +import torch +from fairseq.data import ( + BacktranslationDataset, + LanguagePairDataset, + TransformEosDataset, +) +from fairseq.sequence_generator import SequenceGenerator + + +class TestBacktranslationDataset(unittest.TestCase): + def setUp(self): + ( + self.tgt_dict, + self.w1, + self.w2, + self.src_tokens, + self.src_lengths, + self.model, + ) = test_utils.sequence_generator_setup() + + dummy_src_samples = self.src_tokens + + self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples) + self.cuda = torch.cuda.is_available() + + def _backtranslation_dataset_helper( + self, + remove_eos_from_input_src, + remove_eos_from_output_src, + ): + tgt_dataset = LanguagePairDataset( + src=self.tgt_dataset, + src_sizes=self.tgt_dataset.sizes, + src_dict=self.tgt_dict, + tgt=None, + tgt_sizes=None, + tgt_dict=None, + ) + + generator = SequenceGenerator( + [self.model], + tgt_dict=self.tgt_dict, + max_len_a=0, + max_len_b=200, + beam_size=2, + unk_penalty=0, + ) + + backtranslation_dataset = BacktranslationDataset( + tgt_dataset=TransformEosDataset( + dataset=tgt_dataset, + eos=self.tgt_dict.eos(), + # remove eos from the input src + remove_eos_from_src=remove_eos_from_input_src, + ), + src_dict=self.tgt_dict, + backtranslation_fn=( + lambda sample: generator.generate([self.model], sample) + ), + output_collater=TransformEosDataset( + dataset=tgt_dataset, + eos=self.tgt_dict.eos(), + # if we remove eos from the input src, then we need to add it + # back to the output tgt + append_eos_to_tgt=remove_eos_from_input_src, + remove_eos_from_src=remove_eos_from_output_src, + ).collater, + cuda=self.cuda, + ) + dataloader = torch.utils.data.DataLoader( + backtranslation_dataset, + batch_size=2, + collate_fn=backtranslation_dataset.collater, + ) + backtranslation_batch_result = next(iter(dataloader)) + + eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2 + + # Note that we sort by src_lengths and add left padding, so actually + # ids will look like: [1, 0] + expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]]) + if remove_eos_from_output_src: + expected_src = expected_src[:, :-1] + expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]]) + generated_src = backtranslation_batch_result["net_input"]["src_tokens"] + tgt_tokens = backtranslation_batch_result["target"] + + self.assertTensorEqual(expected_src, generated_src) + self.assertTensorEqual(expected_tgt, tgt_tokens) + + def test_backtranslation_dataset_no_eos_in_output_src(self): + self._backtranslation_dataset_helper( + remove_eos_from_input_src=False, + remove_eos_from_output_src=True, + ) + + def test_backtranslation_dataset_with_eos_in_output_src(self): + self._backtranslation_dataset_helper( + remove_eos_from_input_src=False, + remove_eos_from_output_src=False, + ) + + def test_backtranslation_dataset_no_eos_in_input_src(self): + self._backtranslation_dataset_helper( + remove_eos_from_input_src=True, + remove_eos_from_output_src=False, + ) + + def assertTensorEqual(self, t1, t2): + self.assertEqual(t1.size(), t2.size(), "size mismatch") + self.assertEqual(t1.ne(t2).long().sum(), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_binaries.py b/fairseq/tests/test_binaries.py new file mode 100644 index 0000000000000000000000000000000000000000..41d9210e7c44843daad429d13c658cf136820224 --- /dev/null +++ b/fairseq/tests/test_binaries.py @@ -0,0 +1,1915 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import json +import logging +import os +import random +import sys +import tempfile +import unittest +from packaging import version +from io import StringIO +from typing import Dict, List + +import torch + +from fairseq import options +from fairseq_cli import eval_lm, train +from tests.utils import ( + create_dummy_data, + create_laser_data_and_config_json, + generate_main, + preprocess_lm_data, + preprocess_summarization_data, + preprocess_translation_data, + train_language_model, + train_translation_model, +) + +try: + import transformers # noqa + + has_hf_transformers = True +except ImportError: + has_hf_transformers = False + + +class TestTranslation(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_fconv(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fconv") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, "fconv_iwslt_de_en") + generate_main(data_dir) + + def test_raw(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fconv_raw") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, ["--dataset-impl", "raw"]) + train_translation_model( + data_dir, "fconv_iwslt_de_en", ["--dataset-impl", "raw"] + ) + generate_main(data_dir, ["--dataset-impl", "raw"]) + + def test_update_freq(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_update_freq") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, "fconv_iwslt_de_en", ["--update-freq", "3"] + ) + generate_main(data_dir) + + def test_max_positions(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_max_positions") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + with self.assertRaises(Exception) as context: + train_translation_model( + data_dir, + "fconv_iwslt_de_en", + ["--max-target-positions", "5"], + ) + self.assertTrue( + "skip this example with --skip-invalid-size-inputs-valid-test" + in str(context.exception) + ) + train_translation_model( + data_dir, + "fconv_iwslt_de_en", + [ + "--max-target-positions", + "5", + "--skip-invalid-size-inputs-valid-test", + ], + ) + with self.assertRaises(Exception) as context: + generate_main(data_dir) + generate_main(data_dir, ["--skip-invalid-size-inputs-valid-test"]) + + def test_generation(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_sampling") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, "fconv_iwslt_de_en") + generate_main( + data_dir, + [ + "--sampling", + "--temperature", + "2", + "--beam", + "2", + "--nbest", + "2", + ], + ) + generate_main( + data_dir, + [ + "--sampling", + "--sampling-topk", + "3", + "--beam", + "2", + "--nbest", + "2", + ], + ) + generate_main( + data_dir, + [ + "--sampling", + "--sampling-topp", + "0.2", + "--beam", + "2", + "--nbest", + "2", + ], + ) + generate_main( + data_dir, + [ + "--diversity-rate", + "0.5", + "--beam", + "6", + ], + ) + with self.assertRaises(ValueError): + generate_main( + data_dir, + [ + 
"--diverse-beam-groups", + "4", + "--match-source-len", + ], + ) + generate_main(data_dir, ["--prefix-size", "2"]) + generate_main(data_dir, ["--retain-dropout"]) + + def test_eval_bleu(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_eval_bleu") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "fconv_iwslt_de_en", + [ + "--eval-bleu", + "--eval-bleu-print-samples", + "--eval-bleu-remove-bpe", + "--eval-bleu-detok", + "space", + "--eval-bleu-args", + '{"beam": 4, "min_len": 10}', + ], + ) + + def test_lstm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_lstm") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "lstm_wiseman_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--decoder-out-embed-dim", + "8", + ], + ) + generate_main(data_dir) + + def test_lstm_bidirectional(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_lstm_bidirectional") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "lstm", + [ + "--encoder-layers", + "2", + "--encoder-bidirectional", + "--encoder-hidden-size", + "16", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--decoder-out-embed-dim", + "8", + "--decoder-layers", + "2", + ], + ) + generate_main(data_dir) + + def test_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ], + run_validation=True, + ) + generate_main(data_dir) + + def test_multilingual_transformer(self): + # test with all combinations of encoder/decoder lang tokens + encoder_langtok_flags = [ + [], + ["--encoder-langtok", "src"], + ["--encoder-langtok", "tgt"], + ] + decoder_langtok_flags = [[], ["--decoder-langtok"]] + with contextlib.redirect_stdout(StringIO()): + for i in range(len(encoder_langtok_flags)): + for j in range(len(decoder_langtok_flags)): + enc_ltok_flag = encoder_langtok_flags[i] + dec_ltok_flag = decoder_langtok_flags[j] + with tempfile.TemporaryDirectory( + f"test_multilingual_transformer_{i}_{j}" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + arch="multilingual_transformer", + task="multilingual_translation", + extra_flags=[ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out,out-in"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "multilingual_translation", + "--lang-pairs", + "in-out,out-in", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, + ) + + @unittest.skipIf( + sys.platform.lower() == "darwin", "skip latent depth test on MacOS" + ) + def test_multilingual_translation_latent_depth(self): + # test with 
latent depth in encoder, decoder, or both + encoder_latent_layer = [[], ["--encoder-latent-layer"]] + decoder_latent_layer = [[], ["--decoder-latent-layer"]] + with contextlib.redirect_stdout(StringIO()): + for i in range(len(encoder_latent_layer)): + for j in range(len(decoder_latent_layer)): + if i == 0 and j == 0: + continue + enc_ll_flag = encoder_latent_layer[i] + dec_ll_flag = decoder_latent_layer[j] + with tempfile.TemporaryDirectory( + f"test_multilingual_translation_latent_depth_{i}_{j}" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data( + data_dir, extra_flags=["--joined-dictionary"] + ) + train_translation_model( + data_dir, + arch="latent_multilingual_transformer", + task="multilingual_translation_latent_depth", + extra_flags=[ + "--user-dir", + "examples/latent_depth/latent_depth_src", + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--share-encoders", + "--share-decoders", + "--sparsity-weight", + "0.1", + ] + + enc_ll_flag + + dec_ll_flag, + lang_flags=["--lang-pairs", "in-out,out-in"], + run_validation=True, + extra_valid_flags=[ + "--user-dir", + "examples/latent_depth/latent_depth_src", + ] + + enc_ll_flag + + dec_ll_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--user-dir", + "examples/latent_depth/latent_depth_src", + "--task", + "multilingual_translation_latent_depth", + "--lang-pairs", + "in-out,out-in", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ll_flag + + dec_ll_flag, + ) + + def test_translation_multi_simple_epoch(self): + # test with all combinations of encoder/decoder lang tokens + encoder_langtok_flags = [ + [], + ["--encoder-langtok", "src"], + ["--encoder-langtok", "tgt"], + ] + decoder_langtok_flags = [[], ["--decoder-langtok"]] + with contextlib.redirect_stdout(StringIO()): + for i in range(len(encoder_langtok_flags)): + for j in range(len(decoder_langtok_flags)): + enc_ltok_flag = encoder_langtok_flags[i] + dec_ltok_flag = decoder_langtok_flags[j] + with tempfile.TemporaryDirectory( + f"test_translation_multi_simple_epoch_{i}_{j}" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data( + data_dir, extra_flags=["--joined-dictionary"] + ) + train_translation_model( + data_dir, + arch="transformer", + task="translation_multi_simple_epoch", + extra_flags=[ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + "temperature", + "--sampling-temperature", + "1.5", + "--virtual-epoch-size", + "1000", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out,out-in"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out,out-in", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, + ) + + def test_translation_multi_simple_epoch_no_vepoch(self): + # test with all combinations of encoder/decoder lang tokens + with contextlib.redirect_stdout(StringIO()): + enc_ltok_flag = ["--encoder-langtok", "src"] + dec_ltok_flag = ["--decoder-langtok"] + with tempfile.TemporaryDirectory( + "test_translation_multi_simple_epoch_dict" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, extra_flags=[]) + train_translation_model( + data_dir, + arch="transformer", + 
task="translation_multi_simple_epoch", + extra_flags=[ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + "temperature", + "--sampling-temperature", + "1.5", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, + ) + + def test_translation_multi_simple_epoch_dicts(self): + # test with all combinations of encoder/decoder lang tokens + with contextlib.redirect_stdout(StringIO()): + enc_ltok_flag = ["--encoder-langtok", "src"] + dec_ltok_flag = ["--decoder-langtok"] + with tempfile.TemporaryDirectory( + "test_translation_multi_simple_epoch_dict" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, extra_flags=[]) + train_translation_model( + data_dir, + arch="transformer", + task="translation_multi_simple_epoch", + extra_flags=[ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + "temperature", + "--sampling-temperature", + "1.5", + "--virtual-epoch-size", + "1000", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, + ) + + def test_translation_multi_simple_epoch_src_tgt_dict_spec(self): + # test the specification of explicit --src-dict and --tgt-dict + with contextlib.redirect_stdout(StringIO()): + enc_ltok_flag = ["--encoder-langtok", "src"] + dec_ltok_flag = ["--decoder-langtok"] + with tempfile.TemporaryDirectory( + "test_translation_multi_simple_epoch_dict" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, extra_flags=[]) + train_translation_model( + data_dir, + arch="transformer", + task="translation_multi_simple_epoch", + extra_flags=[ + "--source-dict", + f"{data_dir}/dict.in.txt", + "--target-dict", + f"{data_dir}/dict.out.txt", + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + "temperature", + "--sampling-temperature", + "1.5", + "--virtual-epoch-size", + "1000", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, + ) + + def test_transformer_cross_self_attention(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_transformer_cross_self_attention" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", 
+ "--decoder-embed-dim", + "8", + "--no-cross-attention", + "--cross-self-attention", + ], + run_validation=True, + ) + generate_main(data_dir, extra_flags=[]) + + @unittest.skipIf( + version.parse(torch.__version__) > version.parse("1.8"), + "skip for latest torch versions", + ) + def test_transformer_pointer_generator(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_transformer_pointer_generator" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_summarization_data(data_dir) + train_translation_model( + data_dir, + "transformer_pointer_generator", + extra_flags=[ + "--user-dir", + "examples/pointer_generator/pointer_generator_src", + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--alignment-layer", + "-1", + "--alignment-heads", + "1", + "--source-position-markers", + "0", + ], + run_validation=True, + extra_valid_flags=[ + "--user-dir", + "examples/pointer_generator/pointer_generator_src", + ], + ) + generate_main( + data_dir, + extra_flags=[ + "--user-dir", + "examples/pointer_generator/pointer_generator_src", + ], + ) + + def test_lightconv(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_lightconv") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "lightconv_iwslt_de_en", + [ + "--encoder-conv-type", + "lightweight", + "--decoder-conv-type", + "lightweight", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ], + ) + generate_main(data_dir) + + def test_dynamicconv(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_dynamicconv") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "lightconv_iwslt_de_en", + [ + "--encoder-conv-type", + "dynamic", + "--decoder-conv-type", + "dynamic", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ], + ) + generate_main(data_dir) + + def test_cmlm_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_cmlm_transformer") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "cmlm_transformer", + [ + "--apply-bert-init", + "--criterion", + "nat_loss", + "--noise", + "full_mask", + "--pred-length-offset", + "--length-loss-factor", + "0.1", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) + + def test_nonautoregressive_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_nonautoregressive_transformer" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "nonautoregressive_transformer", + [ + "--apply-bert-init", + "--src-embedding-copy", + "--criterion", + "nat_loss", + "--noise", + "full_mask", + "--pred-length-offset", + "--length-loss-factor", + "0.1", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "0", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) + + # def test_nat_crf_transformer(self): + # with 
contextlib.redirect_stdout(StringIO()): + # with tempfile.TemporaryDirectory('test_nat_crf_transformer') as data_dir: + # create_dummy_data(data_dir) + # preprocess_translation_data(data_dir, ['--joined-dictionary']) + # train_translation_model(data_dir, 'nacrf_transformer', [ + # '--apply-bert-init', '--criterion', + # 'nat_loss', '--noise', 'full_mask', '--pred-length-offset', + # '--length-loss-factor', '0.1', + # '--word-ins-loss-factor', '0.5', + # '--crf-lowrank-approx', '1', + # '--crf-beam-approx', '1' + # ], task='translation_lev') + # generate_main(data_dir, [ + # '--task', 'translation_lev', + # '--iter-decode-max-iter', '0', + # '--iter-decode-eos-penalty', '0', + # '--print-step', + # ]) + + def test_iterative_nonautoregressive_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_iterative_nonautoregressive_transformer" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "iterative_nonautoregressive_transformer", + [ + "--apply-bert-init", + "--src-embedding-copy", + "--criterion", + "nat_loss", + "--noise", + "full_mask", + "--stochastic-approx", + "--dae-ratio", + "0.5", + "--train-step", + "3", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) + + def test_insertion_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_insertion_transformer") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "insertion_transformer", + [ + "--apply-bert-init", + "--criterion", + "nat_loss", + "--noise", + "random_mask", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) + + def test_mixture_of_experts(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_moe") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--task", + "translation_moe", + "--user-dir", + "examples/translation_moe/translation_moe_src", + "--method", + "hMoElp", + "--mean-pool-gating-network", + "--num-experts", + "3", + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ], + ) + generate_main( + data_dir, + [ + "--task", + "translation_moe", + "--user-dir", + "examples/translation_moe/translation_moe_src", + "--method", + "hMoElp", + "--mean-pool-gating-network", + "--num-experts", + "3", + "--gen-expert", + "0", + ], + ) + + def test_alignment(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_alignment") as data_dir: + create_dummy_data(data_dir, alignment=True) + preprocess_translation_data(data_dir, ["--align-suffix", "align"]) + train_translation_model( + data_dir, + "transformer_align", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--load-alignments", + "--alignment-layer", + "1", + "--criterion", + "label_smoothed_cross_entropy_with_alignment", + ], + run_validation=True, + ) + 
generate_main(data_dir) + + def test_laser_lstm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_laser_lstm") as data_dir: + laser_config_file = create_laser_data_and_config_json(data_dir) + train_translation_model( + laser_config_file.name, + "laser_lstm", + [ + "--user-dir", + "examples/laser/laser_src", + "--weighting-alpha", + "0.3", + "--encoder-bidirectional", + "--encoder-hidden-size", + "512", + "--encoder-layers", + "5", + "--decoder-layers", + "1", + "--encoder-embed-dim", + "320", + "--decoder-embed-dim", + "320", + "--decoder-lang-embed-dim", + "32", + "--save-dir", + data_dir, + "--disable-validation", + ], + task="laser", + lang_flags=[], + ) + + def test_laser_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_laser_transformer") as data_dir: + laser_config_file = create_laser_data_and_config_json(data_dir) + train_translation_model( + laser_config_file.name, + "laser_transformer", + [ + "--user-dir", + "examples/laser/laser_src", + "--weighting-alpha", + "0.3", + "--encoder-embed-dim", + "320", + "--decoder-embed-dim", + "320", + "--decoder-lang-embed-dim", + "32", + "--save-dir", + data_dir, + "--disable-validation", + ], + task="laser", + lang_flags=[], + ) + + def test_alignment_full_context(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_alignment") as data_dir: + create_dummy_data(data_dir, alignment=True) + preprocess_translation_data(data_dir, ["--align-suffix", "align"]) + train_translation_model( + data_dir, + "transformer_align", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--load-alignments", + "--alignment-layer", + "1", + "--criterion", + "label_smoothed_cross_entropy_with_alignment", + "--full-context-alignment", + ], + run_validation=True, + ) + generate_main(data_dir) + + def test_transformer_layerdrop(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_layerdrop") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "3", + "--decoder-layers", + "3", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--encoder-layerdrop", + "0.01", + "--decoder-layerdrop", + "0.01", + ], + ) + generate_main(data_dir) + generate_main( + data_dir, + [ + "--model-overrides", + "{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}", + ], + ) + + +class TestStories(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_fconv_self_att_wp(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fconv_self_att_wp") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + config = [ + "--encoder-layers", + "[(128, 3)] * 2", + "--decoder-layers", + "[(128, 3)] * 2", + "--decoder-attention", + "True", + "--encoder-attention", + "False", + "--gated-attention", + "True", + "--self-attention", + "True", + "--project-input", + "True", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--decoder-out-embed-dim", + "8", + "--multihead-self-attention-nheads", + "2", + ] + train_translation_model(data_dir, "fconv_self_att_wp", config) + generate_main(data_dir) + + # fusion model + os.rename( + 
os.path.join(data_dir, "checkpoint_last.pt"), + os.path.join(data_dir, "pretrained.pt"), + ) + config.extend( + [ + "--pretrained", + "True", + "--pretrained-checkpoint", + os.path.join(data_dir, "pretrained.pt"), + "--save-dir", + os.path.join(data_dir, "fusion_model"), + ] + ) + train_translation_model(data_dir, "fconv_self_att_wp", config) + + +class TestLanguageModeling(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_fconv_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fconv_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "fconv_lm", + [ + "--decoder-layers", + "[(850, 3)] * 2 + [(1024,4)]", + "--decoder-embed-dim", + "280", + "--optimizer", + "nag", + "--lr", + "0.1", + ], + ) + eval_lm_main(data_dir) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + + def test_transformer_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "transformer_lm", + ["--add-bos-token", "--nval", "1"], + run_validation=True, + ) + eval_lm_main(data_dir) + eval_lm_main(data_dir, extra_flags=["--context-window", "25"]) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + + def test_normformer_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "transformer_lm", + [ + "--add-bos-token", + "--nval", + "1", + "--scale-fc", + "--scale-heads", + "--scale-attn", + "--scale-fc", + ], + run_validation=True, + ) + eval_lm_main(data_dir) + eval_lm_main(data_dir, extra_flags=["--context-window", "25"]) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + + def test_transformer_lm_with_adaptive_softmax(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_transformer_lm_with_adaptive_softmax" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "transformer_lm", + [ + "--add-bos-token", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + ], + run_validation=True, + ) + eval_lm_main(data_dir) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + + def test_lightconv_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_lightconv_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "lightconv_lm", + ["--add-bos-token"], + run_validation=True, + ) + eval_lm_main(data_dir) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + + def test_lstm_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_lstm_lm") as data_dir: + 
create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "lstm_lm", + ["--add-bos-token"], + run_validation=True, + ) + eval_lm_main(data_dir) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + + def test_lstm_lm_residuals(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_lstm_lm_residuals") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "lstm_lm", + ["--add-bos-token", "--residuals"], + run_validation=True, + ) + eval_lm_main(data_dir) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + + @unittest.skipIf(not has_hf_transformers, "skip test if transformers is missing") + def test_transformer_xl_bptt_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_xl_bptt_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + task_flags = [ + "--user-dir", + "examples/truncated_bptt", + "--task", + "truncated_bptt_lm", + "--batch-size", + "2", + "--tokens-per-sample", + "50", + ] + train_language_model( + data_dir=data_dir, + arch="transformer_xl", + extra_flags=task_flags + + [ + "--n-layer", + "2", + ], + task="truncated_bptt_lm", + run_validation=True, + extra_valid_flags=task_flags, + ) + eval_lm_main(data_dir, extra_flags=task_flags) + # Train with activation offloading + train_language_model( + data_dir=data_dir, + arch="transformer_xl", + extra_flags=task_flags + + [ + "--n-layer", + "2", + "--offload-activations", + ], + task="truncated_bptt_lm", + run_validation=True, + extra_valid_flags=task_flags, + ) + + +class TestMaskedLanguageModel(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_legacy_masked_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_legacy_mlm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_legacy_masked_language_model(data_dir, "masked_lm") + + def test_roberta_masked_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_roberta_mlm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_masked_lm( + data_dir, "roberta_base", extra_flags=["--encoder-layers", "2"] + ) + + def test_roberta_sentence_prediction(self): + num_classes = 3 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_roberta_head") as data_dir: + create_dummy_roberta_head_data(data_dir, num_classes=num_classes) + preprocess_lm_data(os.path.join(data_dir, "input0")) + preprocess_lm_data(os.path.join(data_dir, "label")) + train_roberta_head(data_dir, "roberta_base", num_classes=num_classes) + + def test_roberta_regression_single(self): + num_classes = 1 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_roberta_regression_single" + ) as data_dir: + create_dummy_roberta_head_data( + data_dir, num_classes=num_classes, regression=True + ) + preprocess_lm_data(os.path.join(data_dir, "input0")) + train_roberta_head( + data_dir, + "roberta_base", + num_classes=num_classes, + extra_flags=["--regression-target"], + ) + + def test_roberta_regression_multiple(self): + num_classes = 3 
+ with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_roberta_regression_multiple" + ) as data_dir: + create_dummy_roberta_head_data( + data_dir, num_classes=num_classes, regression=True + ) + preprocess_lm_data(os.path.join(data_dir, "input0")) + train_roberta_head( + data_dir, + "roberta_base", + num_classes=num_classes, + extra_flags=["--regression-target"], + ) + + def test_linformer_roberta_masked_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_linformer_roberta_mlm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_masked_lm( + data_dir, + "linformer_roberta_base", + extra_flags=[ + "--user-dir", + "examples/linformer/linformer_src", + "--encoder-layers", + "2", + ], + ) + + def test_linformer_roberta_sentence_prediction(self): + num_classes = 3 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_linformer_roberta_head") as data_dir: + create_dummy_roberta_head_data(data_dir, num_classes=num_classes) + preprocess_lm_data(os.path.join(data_dir, "input0")) + preprocess_lm_data(os.path.join(data_dir, "label")) + train_roberta_head( + data_dir, + "linformer_roberta_base", + num_classes=num_classes, + extra_flags=["--user-dir", "examples/linformer/linformer_src"], + ) + + def test_linformer_roberta_regression_single(self): + num_classes = 1 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_linformer_roberta_regression_single" + ) as data_dir: + create_dummy_roberta_head_data( + data_dir, num_classes=num_classes, regression=True + ) + preprocess_lm_data(os.path.join(data_dir, "input0")) + train_roberta_head( + data_dir, + "linformer_roberta_base", + num_classes=num_classes, + extra_flags=[ + "--regression-target", + "--user-dir", + "examples/linformer/linformer_src", + ], + ) + + def test_linformer_roberta_regression_multiple(self): + num_classes = 3 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_linformer_roberta_regression_multiple" + ) as data_dir: + create_dummy_roberta_head_data( + data_dir, num_classes=num_classes, regression=True + ) + preprocess_lm_data(os.path.join(data_dir, "input0")) + train_roberta_head( + data_dir, + "linformer_roberta_base", + num_classes=num_classes, + extra_flags=[ + "--regression-target", + "--user-dir", + "examples/linformer/linformer_src", + ], + ) + + def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_mlm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_legacy_masked_language_model( + data_dir, + arch="masked_lm", + extra_args=("--encoder-learned-pos",) if learned_pos_emb else (), + ) + with tempfile.TemporaryDirectory( + "test_mlm_translation" + ) as translation_dir: + create_dummy_data(translation_dir) + preprocess_translation_data( + translation_dir, extra_flags=["--joined-dictionary"] + ) + # Train transformer with data_dir/checkpoint_last.pt + train_translation_model( + translation_dir, + arch="transformer_from_pretrained_xlm", + extra_flags=[ + "--decoder-layers", + "1", + "--decoder-embed-dim", + "32", + "--decoder-attention-heads", + "1", + "--decoder-ffn-embed-dim", + "32", + "--encoder-layers", + "1", + "--encoder-embed-dim", + "32", + "--encoder-attention-heads", + "1", + "--encoder-ffn-embed-dim", + "32", + "--pretrained-xlm-checkpoint", + 
"{}/checkpoint_last.pt".format(data_dir), + "--activation-fn", + "gelu", + "--max-source-positions", + "500", + "--max-target-positions", + "500", + ] + + ( + ["--encoder-learned-pos", "--decoder-learned-pos"] + if learned_pos_emb + else [] + ) + + (["--init-encoder-only"] if encoder_only else []), + task="translation_from_pretrained_xlm", + ) + + def test_pretrained_masked_lm_for_translation_learned_pos_emb(self): + self._test_pretrained_masked_lm_for_translation(True, False) + + def test_pretrained_masked_lm_for_translation_sinusoidal_pos_emb(self): + self._test_pretrained_masked_lm_for_translation(False, False) + + def test_pretrained_masked_lm_for_translation_encoder_only(self): + self._test_pretrained_masked_lm_for_translation(True, True) + + def test_r4f_roberta(self): + num_classes = 3 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_r4f_roberta_head") as data_dir: + create_dummy_roberta_head_data(data_dir, num_classes=num_classes) + preprocess_lm_data(os.path.join(data_dir, "input0")) + preprocess_lm_data(os.path.join(data_dir, "label")) + train_roberta_head( + data_dir, + "roberta_base", + num_classes=num_classes, + extra_flags=[ + "--user-dir", + "examples/rxf/rxf_src", + "--criterion", + "sentence_prediction_r3f", + "--spectral-norm-classification-head", + ], + ) + + +def train_legacy_masked_language_model(data_dir, arch, extra_args=()): + train_parser = options.get_training_parser() + # TODO: langs should be in and out right? + train_args = options.parse_args_and_arch( + train_parser, + [ + "--task", + "cross_lingual_lm", + data_dir, + "--arch", + arch, + # Optimizer args + "--optimizer", + "adam", + "--lr-scheduler", + "reduce_lr_on_plateau", + "--lr-shrink", + "0.5", + "--lr", + "0.0001", + "--stop-min-lr", + "1e-09", + # dropout, attention args + "--dropout", + "0.1", + "--attention-dropout", + "0.1", + # MLM args + "--criterion", + "legacy_masked_lm_loss", + "--masked-lm-only", + "--monolingual-langs", + "in,out", + "--num-segment", + "5", + # Transformer args: use a small transformer model for fast training + "--encoder-layers", + "1", + "--encoder-embed-dim", + "32", + "--encoder-attention-heads", + "1", + "--encoder-ffn-embed-dim", + "32", + # Other training args + "--max-tokens", + "500", + "--tokens-per-sample", + "500", + "--save-dir", + data_dir, + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--dataset-impl", + "raw", + "--num-workers", + "0", + ] + + list(extra_args), + ) + train.main(train_args) + + +class TestOptimizers(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_optimizers(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_optimizers") as data_dir: + # Use just a bit of data and tiny model to keep this test runtime reasonable + create_dummy_data(data_dir, num_examples=10, maxlen=5) + preprocess_translation_data(data_dir) + optimizers = ["adafactor", "adam", "nag", "adagrad", "sgd", "adadelta"] + last_checkpoint = os.path.join(data_dir, "checkpoint_last.pt") + for optimizer in optimizers: + if os.path.exists(last_checkpoint): + os.remove(last_checkpoint) + train_translation_model( + data_dir, + "lstm", + [ + "--required-batch-size-multiple", + "1", + "--encoder-layers", + "1", + "--encoder-hidden-size", + "32", + "--decoder-layers", + "1", + "--optimizer", + optimizer, + ], + ) + generate_main(data_dir) + + +def read_last_log_entry( + logs: 
List[logging.LogRecord], logger_name: str
+) -> Dict[str, float]:
+    for x in reversed(logs):
+        if x.name == logger_name:
+            return json.loads(x.message)
+    raise ValueError(f"No entries from {logger_name} found in captured logs")
+
+
+class TestActivationCheckpointing(unittest.TestCase):
+    base_flags = [
+        "--encoder-layers",
+        "2",
+        "--decoder-layers",
+        "2",
+        "--encoder-embed-dim",
+        "8",
+        "--decoder-embed-dim",
+        "8",
+        "--restore-file",
+        "x.pt",
+        "--log-format",
+        "json",
+        "--log-interval",
+        "1",
+        "--max-update",
+        "2",
+    ]
+
+    def _train(self, data_dir, extra_flags):
+        with self.assertLogs() as logs:
+            train_translation_model(
+                data_dir,
+                "transformer_iwslt_de_en",
+                self.base_flags + extra_flags,
+                run_validation=True,
+                extra_valid_flags=["--log-format", "json"],
+            )
+        return logs.records
+
+    def test_activation_offloading_does_not_change_metrics(self):
+        """Neither --checkpoint-activations nor --offload-activations should change loss"""
+        with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
+
+            with self.assertLogs():
+                create_dummy_data(data_dir, num_examples=20)
+                preprocess_translation_data(data_dir)
+            offload_logs = self._train(data_dir, ["--offload-activations"])
+            baseline_logs = self._train(data_dir, [])
+
+            assert len(baseline_logs) == len(offload_logs)
+
+            baseline_valid_stats = read_last_log_entry(baseline_logs, "valid")
+            offload_valid_stats = read_last_log_entry(offload_logs, "valid")
+            baseline_train_stats = read_last_log_entry(baseline_logs, "train")
+            offload_train_stats = read_last_log_entry(offload_logs, "train")
+
+            assert (
+                baseline_train_stats["train_loss"] == offload_train_stats["train_loss"]
+            )
+            assert (
+                baseline_valid_stats["valid_loss"] == offload_valid_stats["valid_loss"]
+            )
+
+    def test_activation_checkpointing_does_not_change_metrics(self):
+        """--checkpoint-activations should not change loss"""
+
+        with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
+            with self.assertLogs():
+                create_dummy_data(data_dir, num_examples=20)
+                preprocess_translation_data(data_dir)
+            ckpt_logs = self._train(data_dir, ["--checkpoint-activations"])
+            baseline_logs = self._train(data_dir, [])
+            assert len(baseline_logs) == len(ckpt_logs)
+
+            baseline_train_stats = read_last_log_entry(baseline_logs, "train")
+            ckpt_train_stats = read_last_log_entry(ckpt_logs, "train")
+            assert baseline_train_stats["train_loss"] == ckpt_train_stats["train_loss"]
+
+            baseline_valid_stats = read_last_log_entry(baseline_logs, "valid")
+            ckpt_valid_stats = read_last_log_entry(ckpt_logs, "valid")
+            assert baseline_valid_stats["valid_loss"] == ckpt_valid_stats["valid_loss"]
+
+
+def create_dummy_roberta_head_data(
+    data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False
+):
+    input_dir = "input0"
+
+    def _create_dummy_data(filename):
+        random_data = torch.rand(num_examples * maxlen)
+        input_data = 97 + torch.floor(26 * random_data).int()
+        if regression:
+            output_data = torch.rand((num_examples, num_classes))
+        else:
+            output_data = 1 + torch.floor(num_classes * torch.rand(num_examples)).int()
+        with open(os.path.join(data_dir, input_dir, filename + ".out"), "w") as f_in:
+            label_filename = filename + ".label" if regression else filename + ".out"
+            with open(os.path.join(data_dir, "label", label_filename), "w") as f_out:
+                offset = 0
+                for i in range(num_examples):
+                    # write example input
+                    ex_len = random.randint(1, maxlen)
+                    ex_str = " ".join(map(chr, input_data[offset : offset + ex_len]))
+                    print(ex_str, file=f_in)
+ # write example label + if regression: + class_str = " ".join(map(str, output_data[i].numpy())) + print(class_str, file=f_out) + else: + class_str = "class{}".format(output_data[i]) + print(class_str, file=f_out) + offset += ex_len + + os.mkdir(os.path.join(data_dir, input_dir)) + os.mkdir(os.path.join(data_dir, "label")) + _create_dummy_data("train") + _create_dummy_data("valid") + _create_dummy_data("test") + + +def train_masked_lm(data_dir, arch, extra_flags=None): + train_parser = options.get_training_parser() + train_args = options.parse_args_and_arch( + train_parser, + [ + "--task", + "masked_lm", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "masked_lm", + "--batch-size", + "500", + "--required-batch-size-multiple", + "1", + "--save-dir", + data_dir, + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + "0", + ] + + (extra_flags or []), + ) + train.main(train_args) + + +def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): + train_parser = options.get_training_parser() + train_args = options.parse_args_and_arch( + train_parser, + [ + "--task", + "sentence_prediction", + data_dir, + "--arch", + arch, + "--encoder-layers", + "2", + "--num-classes", + str(num_classes), + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "sentence_prediction", + "--max-tokens", + "500", + "--max-positions", + "500", + "--batch-size", + "500", + "--save-dir", + data_dir, + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + "0", + ] + + (extra_flags or []), + ) + train.main(train_args) + + +def eval_lm_main(data_dir, extra_flags=None): + eval_lm_parser = options.get_eval_lm_parser() + eval_lm_args = options.parse_args_and_arch( + eval_lm_parser, + [ + data_dir, + "--path", + os.path.join(data_dir, "checkpoint_last.pt"), + "--no-progress-bar", + "--num-workers", + "0", + ] + + (extra_flags or []), + ) + eval_lm.main(eval_lm_args) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_binarizer.py b/fairseq/tests/test_binarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..50075eabcc7e915d9ba90458e5a23c0201727fe2 --- /dev/null +++ b/fairseq/tests/test_binarizer.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
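+
+# Editorial sketch, not part of the original patch: binarization maps a line of
+# whitespace-separated tokens onto a sequence of vocabulary indices, optionally
+# appending EOS (the tests below check the resulting lengths and token counts).
+# `_demo_binarize_line` is a hypothetical stand-alone version of that mapping,
+# assuming a fairseq Dictionary-like object with index() and eos().
+def _demo_binarize_line(line, vocab, append_eos=True):
+    ids = [vocab.index(token) for token in line.split()]
+    if append_eos:
+        ids.append(vocab.eos())
+    return ids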
+ + +import os +import typing as tp +import unittest +from tempfile import TemporaryDirectory + +from fairseq.binarizer import BinarizeSummary, FileBinarizer, VocabularyDatasetBinarizer +from fairseq.data import Dictionary, indexed_dataset +from tests.utils import make_data, sizes + + +def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary: + d = Dictionary() + for s in data: + for token in s: + d.add_symbol(token) + d.finalize() + return d + + +class TestBinarizer(unittest.TestCase): + def compare_ds_data(self, summary, data, prefix, impl, vocab): + self.assertEqual(summary.num_seq, len(data)) + self.assertEqual(summary.num_tok, sum([len(s) for s in data])) + + dataset = indexed_dataset.make_dataset(prefix, impl) + + self.assertEqual(len(dataset), len(data)) + decoded = [vocab.string(dataset[i]).split() for i in range(0, len(dataset))] + + self.assertEqual(decoded, data) + data_sizes = [i.item() for i in dataset.sizes] + self.assertEqual(data_sizes, sizes(data)) + + def test_can_binarize_line(self): + data = make_data(length=1) + vocab = build_vocab(data) + + binarizer = VocabularyDatasetBinarizer( + vocab, + ) + + sentence = data[0] + summary = BinarizeSummary() + + tensor = binarizer.binarize_line( + " ".join(sentence), + summary, + ) + + self.assertEqual(len(tensor), len(sentence) + 1) + + self.assertEqual(summary.num_tok, len(sentence) + 1) + self.assertEqual(summary.num_seq, 1) + + def test_can_binarize_file_chunk(self): + # test without multiprocess logic + with TemporaryDirectory() as dirname: + raw_file = os.path.join(dirname, "raw1") + prefix = os.path.join(dirname, "test1") + impl = "mmap" + + data = make_data(out_file=raw_file) + vocab = build_vocab(data) + + binarizer = VocabularyDatasetBinarizer( + vocab, + append_eos=False, + ) + + summary = FileBinarizer._binarize_chunk_and_finalize( + binarizer, + raw_file, + offset_start=0, + offset_end=-1, + output_prefix=prefix, + dataset_impl=impl, + vocab_size=len(vocab), + ) + + self.compare_ds_data(summary, data, prefix, impl, vocab) + + def test_can_multiprocess(self): + with TemporaryDirectory() as dirname: + raw_file = os.path.join(dirname, "raw1") + prefix = os.path.join(dirname, "test1") + impl = "mmap" + data = make_data(out_file=raw_file) + vocab = build_vocab(data) + binarizer = VocabularyDatasetBinarizer( + vocab, + append_eos=False, + ) + # with one worker + summary = FileBinarizer.multiprocess_dataset( + raw_file, + impl, + binarizer, + output_prefix=prefix, + vocab_size=len(vocab), + num_workers=1, + ) + + self.compare_ds_data(summary, data, prefix, impl, vocab) + + # with multiple workers + prefix_multi = os.path.join(dirname, "test2") + summary = FileBinarizer.multiprocess_dataset( + raw_file, + impl, + binarizer, + output_prefix=prefix_multi, + vocab_size=len(vocab), + num_workers=3, + ) + + self.compare_ds_data(summary, data, prefix_multi, impl, vocab) diff --git a/fairseq/tests/test_character_token_embedder.py b/fairseq/tests/test_character_token_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..24940ebd21a0e4465ca6052409353a3179e9cf6d --- /dev/null +++ b/fairseq/tests/test_character_token_embedder.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
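+ +# Reading the constructor call in the test below, the positional arguments to +# CharacterTokenEmbedder are: the dictionary, a list of (kernel width, number +# of filters) pairs for the character convolutions, the character embedding +# dim (64), the output word embedding dim (5), and the number of highway +# layers (2). These argument roles are inferred from the call site, not +# quoted from the module's documentation.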
+ +import unittest + +import torch +from fairseq.data import Dictionary +from fairseq.modules import CharacterTokenEmbedder + + +class TestCharacterTokenEmbedder(unittest.TestCase): + def test_character_token_embedder(self): + vocab = Dictionary() + vocab.add_symbol("hello") + vocab.add_symbol("there") + + embedder = CharacterTokenEmbedder( + vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2 + ) + + test_sents = [["hello", "unk", "there"], ["there"], ["hello", "there"]] + max_len = max(len(s) for s in test_sents) + input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad()) + for i in range(len(test_sents)): + input[i][0] = vocab.eos() + for j in range(len(test_sents[i])): + input[i][j + 1] = vocab.index(test_sents[i][j]) + input[i][j + 2] = vocab.eos() + embs = embedder(input) + + assert embs.size() == (len(test_sents), max_len + 2, 5) + self.assertAlmostEqual(embs[0][0], embs[1][0]) + self.assertAlmostEqual(embs[0][0], embs[0][-1]) + self.assertAlmostEqual(embs[0][1], embs[2][1]) + self.assertAlmostEqual(embs[0][3], embs[1][1]) + + embs.sum().backward() + assert embedder.char_embeddings.weight.grad is not None + + def assertAlmostEqual(self, t1, t2): + self.assertEqual(t1.size(), t2.size(), "size mismatch") + self.assertLess((t1 - t2).abs().max(), 1e-6) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_checkpoint_utils.py b/fairseq/tests/test_checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f8cd943cfa375b6dec99cfec799006ed400416c9 --- /dev/null +++ b/fairseq/tests/test_checkpoint_utils.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
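+ +# A minimal sketch of the ensemble-loading call these tests revolve around +# (filenames and overrides are illustrative): +# +# ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( +# filenames=["ckpt1.pt", "ckpt2.pt"], +# arg_overrides={"encoder_layers_to_keep": "0,2"}, # optional pruning +# )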
+ +import contextlib +import logging +import os +import tempfile +import unittest +from io import StringIO +from unittest.mock import patch + +from fairseq import checkpoint_utils +from tests.utils import ( + create_dummy_data, + preprocess_translation_data, + train_translation_model, +) +import torch + + +class TestCheckpointUtils(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @contextlib.contextmanager + def _train_transformer(self, seed, extra_args=None): + if extra_args is None: + extra_args = [] + with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "3", + "--decoder-layers", + "3", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--seed", + str(seed), + ] + + extra_args, + ) + yield os.path.join(data_dir, "checkpoint_last.pt") + + def test_load_model_ensemble_and_task(self): + # with contextlib.redirect_stdout(StringIO()): + with self._train_transformer(seed=123) as model1: + with self._train_transformer(seed=456) as model2: + ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + filenames=[model1, model2] + ) + self.assertEqual(len(ensemble), 2) + + # after Transformer has been migrated to Hydra, this will probably + # become cfg.common.seed + self.assertEqual(ensemble[0].args.seed, 123) + self.assertEqual(ensemble[1].args.seed, 456) + + # the task from the first model should be returned + self.assertTrue("seed123" in task.cfg.data) + + # last cfg is saved + self.assertEqual(cfg.common.seed, 456) + + def test_prune_state_dict(self): + with contextlib.redirect_stdout(StringIO()): + extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"] + with self._train_transformer(seed=1, extra_args=extra_args) as model: + ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + filenames=[model], + arg_overrides={ + "encoder_layers_to_keep": "0,2", + "decoder_layers_to_keep": "1", + }, + ) + self.assertEqual(len(ensemble), 1) + self.assertEqual(len(ensemble[0].encoder.layers), 2) + self.assertEqual(len(ensemble[0].decoder.layers), 1) + + def test_torch_persistent_save_async(self): + state_dict = {} + filename = "async_checkpoint.pt" + + with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: + with patch( + f"{checkpoint_utils.__name__}._torch_persistent_save" + ) as mock_save: + checkpoint_utils.torch_persistent_save( + state_dict, filename, async_write=True + ) + mock_opena.assert_called_with(filename, "wb") + mock_save.assert_called() + + def test_load_ema_from_checkpoint(self): + dummy_state = {"a": torch.tensor([1]), "b": torch.tensor([0.1])} + with patch(f"{checkpoint_utils.__name__}.PathManager.open") as mock_open, patch( + f"{checkpoint_utils.__name__}.torch.load" + ) as mock_load: + + mock_load.return_value = {"extra_state": {"ema": dummy_state}} + filename = "ema_checkpoint.pt" + state = checkpoint_utils.load_ema_from_checkpoint(filename) + + mock_open.assert_called_with(filename, "rb") + mock_load.assert_called() + + self.assertIn("a", state["model"]) + self.assertIn("b", state["model"]) + self.assertTrue(torch.allclose(dummy_state["a"], state["model"]["a"])) + self.assertTrue(torch.allclose(dummy_state["b"], state["model"]["b"])) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/fairseq/tests/test_checkpoint_utils_for_task_level_attributes.py b/fairseq/tests/test_checkpoint_utils_for_task_level_attributes.py new file mode 100644 index 0000000000000000000000000000000000000000..53ab401f03d5bd84df07f97361badccb5f9e86de --- /dev/null +++ b/fairseq/tests/test_checkpoint_utils_for_task_level_attributes.py @@ -0,0 +1,172 @@ +#!/usr/bin/env fbpython +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import contextlib +import logging +import unittest +from io import StringIO +from unittest.mock import MagicMock, patch + +import torch +from fairseq import checkpoint_utils, data +from omegaconf import OmegaConf + + +def mock_trainer(epoch, num_updates, iterations_in_epoch): + trainer = MagicMock() + trainer.load_checkpoint.return_value = { + "train_iterator": { + "epoch": epoch, + "iterations_in_epoch": iterations_in_epoch, + "shuffle": False, + }, + "FakeTask": checkpoint_dict()["FakeTask"], + } + trainer.get_num_updates.return_value = num_updates + trainer.task.__class__.__name__ = "FakeTask" + trainer.task.get_checkpoint_dict.return_value = checkpoint_dict() + trainer.task.set_checkpoint_dict = MagicMock() + + return trainer + + +def checkpoint_dict(): + return { + "FakeTask": { + "observer_stats": { + ( + 4, + 16, + "MovingAveragePerChannelMinMax", + "MovingAveragePerChannelMinMax", + ): {"mod1": 1, "mod2": 2, "mod3": 3} + } + } + } + + +def mock_dict(): + d = MagicMock() + d.pad.return_value = 1 + d.eos.return_value = 2 + d.unk.return_value = 3 + return d + + +def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch): + tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1) + tokens_ds = data.TokenBlockDataset( + tokens, + sizes=[tokens.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + trainer = mock_trainer(epoch, num_updates, iterations_in_epoch) + dataset = data.LanguagePairDataset( + tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False + ) + epoch_itr = data.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=[[i] for i in range(epoch_size)], + ) + return trainer, epoch_itr + + +def get_mock_cfg(finetune_from_model): + cfg_mock = OmegaConf.create( + { + "checkpoint": { + "save_dir": None, + "optimizer_overrides": "{}", + "reset_dataloader": False, + "reset_meters": False, + "reset_optimizer": False, + "reset_lr_scheduler": False, + "finetune_from_model": finetune_from_model, + "model_parallel_size": 1, + "restore_file": "checkpoint_last.pt", + "no_save": False, + "save_interval_updates": 0, + "no_last_checkpoints": False, + "keep_interval_updates": 0, + "keep_last_epochs": 0, + "keep_best_checkpoints": 0, + }, + "common": { + "model_parallel_size": 1, + }, + } + ) + return cfg_mock + + +class TestCheckpointsForTaskLevelAttributes(unittest.TestCase): + def setUp(self) -> None: + self.cfg_mock = get_mock_cfg(None) + self.patches = { + "os.makedirs": MagicMock(), + "os.path.join": MagicMock(), + "os.path.isfile": MagicMock(return_value=True), + "os.path.isabs": MagicMock(return_value=False), + "fairseq.file_io.PathManager.exists": MagicMock(return_value=False), + } + self.applied_patches = [patch(p, d) for p, d in self.patches.items()] + [p.start() for p in self.applied_patches] + logging.disable(logging.CRITICAL) + + self.trainer, self.epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) + self.trainer.get_train_iterator = MagicMock(return_value=self.epoch_itr) + self.epoch_itr.next_epoch_itr(shuffle=False) + + checkpoint_utils.save_checkpoint( 
+ self.cfg_mock.checkpoint, self.trainer, self.epoch_itr, None + ) + + def tearDown(self): + patch.stopall() + logging.disable(logging.NOTSET) + + def test_verify_checkpoint(self) -> None: + cp_dict = self.trainer.task.get_checkpoint_dict() + self.assertTrue(len(cp_dict) == 1) + self.assertTrue("FakeTask" in cp_dict) + self.assertTrue("observer_stats" in cp_dict["FakeTask"]) + self.assertTrue(len(cp_dict["FakeTask"]["observer_stats"]) == 1) + self.assertTrue( + ( + 4, + 16, + "MovingAveragePerChannelMinMax", + "MovingAveragePerChannelMinMax", + ) + in cp_dict["FakeTask"]["observer_stats"] + ) + self.assertTrue( + cp_dict["FakeTask"]["observer_stats"][ + ( + 4, + 16, + "MovingAveragePerChannelMinMax", + "MovingAveragePerChannelMinMax", + ) + ] + == {"mod1": 1, "mod2": 2, "mod3": 3} + ) + + def test_load_checkpoint(self) -> None: + with contextlib.redirect_stdout(StringIO()): + # Now, load checkpoint to ensure the respective logic works as expected + _, epoch_itr = checkpoint_utils.load_checkpoint( + self.cfg_mock.checkpoint, self.trainer + ) + + self.trainer.task.set_checkpoint_dict.assert_called_once_with( + checkpoint_dict()["FakeTask"] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_concat_dataset.py b/fairseq/tests/test_concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d94aeffd481a2e107eb5747e41d76435b3f3dc8a --- /dev/null +++ b/fairseq/tests/test_concat_dataset.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from fairseq.data import LanguagePairDataset, TokenBlockDataset +from fairseq.data.concat_dataset import ConcatDataset +from tests.test_train import mock_dict + + +class TestConcatDataset(unittest.TestCase): + def setUp(self): + d = mock_dict() + tokens_1 = torch.LongTensor([1]).view(1, -1) + tokens_ds1 = TokenBlockDataset( + tokens_1, + sizes=[tokens_1.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + self.dataset_1 = LanguagePairDataset( + tokens_ds1, tokens_ds1.sizes, d, shuffle=False + ) + tokens_2 = torch.LongTensor([2]).view(1, -1) + tokens_ds2 = TokenBlockDataset( + tokens_2, + sizes=[tokens_2.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + self.dataset_2 = LanguagePairDataset( + tokens_ds2, tokens_ds2.sizes, d, shuffle=False + ) + + def test_concat_dataset_basics(self): + d = ConcatDataset([self.dataset_1, self.dataset_2]) + assert len(d) == 2 + assert d[0]["source"][0] == 1 + assert d[1]["source"][0] == 2 + + d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[1, 2]) + assert len(d) == 3 + assert d[0]["source"][0] == 1 + assert d[1]["source"][0] == 2 + assert d[2]["source"][0] == 2 + + d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[2, 1]) + assert len(d) == 3 + assert d[0]["source"][0] == 1 + assert d[1]["source"][0] == 1 + assert d[2]["source"][0] == 2 diff --git a/fairseq/tests/test_constraints.py b/fairseq/tests/test_constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..d766d5130fb777b1273f061e5e53b36d29bb821b --- /dev/null +++ b/fairseq/tests/test_constraints.py @@ -0,0 +1,275 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
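+ +# Constraint states are advanced one token at a time, mirroring decoding. +# The pattern used throughout these tests (a sketch using only calls tested +# here): +# +# packed = pack_constraints([[torch.tensor([3, 1, 2]), torch.tensor([3])]]) +# state = UnorderedConstraintState.create(packed[0]) +# for token in [3, 1, 2]: +# state = state.advance(token) +# state.num_completed, state.finished # progress bookkeeping checked below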
+ +import unittest +from typing import List + +import torch + +from fairseq.token_generation_constraints import ( + ConstraintNode, + OrderedConstraintState, + UnorderedConstraintState, + pack_constraints, +) + + +def tensorize(constraints: List[List[int]]) -> List[torch.Tensor]: + return [torch.tensor(x) for x in constraints] + + +class TestHelperRoutines(unittest.TestCase): + def setUp(self): + self.examples = [ + ([[]], torch.tensor([[0]])), + ([[], []], torch.tensor([[0], [0]])), + ([[torch.tensor([1, 2])], []], torch.tensor([[1, 1, 2, 0], [0, 0, 0, 0]])), + ( + [ + [ + torch.tensor([3, 1, 2]), + torch.tensor([3]), + torch.tensor([4, 5, 6, 7]), + ], + [], + [torch.tensor([1, 8, 9, 10, 1, 4, 11, 12])], + ], + torch.tensor( + [ + [3, 3, 1, 2, 0, 3, 0, 4, 5, 6, 7, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 8, 9, 10, 1, 4, 11, 12, 0, 0, 0], + ] + ), + ), + ] + + def test_packing(self): + """Ensures the list of lists of tensors gets packed correctly.""" + for batch_constraints, expected_tensor in self.examples: + packed = pack_constraints(batch_constraints) + assert torch.equal(packed, expected_tensor) + + +class TestUnorderedConstraintState(unittest.TestCase): + def setUp(self): + # Tuples of (constraint set, expected printed graph, token counts per node) + self.examples = [ + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + "([None].False#6 ([1].True#4 ([2].False#1 [3].True#1) [3].True#1 [4].True#1) ([4].False#2 ([5].True#2 ([6].False#1 [7].True#1))))", # noqa + {1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1}, + ), + ([], "[None].False#0", {}), + (tensorize([[0]]), "([None].False#1 [0].True#1)", {0: 1}), + ( + tensorize([[100000, 1, 2, 3, 4, 5]]), + "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))", + {100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, + ), + ( + tensorize([[1, 2], [1, 2]]), + "([None].False#2 ([1].False#2 [2].True#2))", + {1: 2, 2: 2}, + ), + ( + tensorize([[1, 2], [3, 4]]), + "([None].False#2 ([1].False#1 [2].True#1) ([3].False#1 [4].True#1))", + {1: 1, 2: 1, 3: 1, 4: 1}, + ), + ] + + self.sequences = [ + ( + self.examples[0][0], + [], + {"bank": 0, "num_completed": 0, "finished": False, "is_root": True}, + ), + ( + self.examples[0][0], + [1, 2], + {"bank": 2, "num_completed": 0, "finished": False, "is_root": False}, + ), + ( + self.examples[0][0], + [1, 2, 94], + {"bank": 1, "num_completed": 1, "finished": False, "is_root": True}, + ), + ( + self.examples[0][0], + [1, 3, 999, 1, 4], + {"bank": 4, "num_completed": 2, "finished": False, "is_root": False}, + ), + ( + self.examples[0][0], + [1, 3, 999, 1, 4, 999], + {"bank": 4, "num_completed": 2, "finished": False, "is_root": True}, + ), + ( + self.examples[0][0], + [4, 5, 6, 8], + {"bank": 2, "num_completed": 1, "finished": False, "is_root": True}, + ), + ( + self.examples[0][0], + # Tricky, because in last three, goes down [1->4] branch, could miss [1] and [4->5] + # [[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]], + [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5], + {"bank": 14, "num_completed": 6, "finished": True, "is_root": False}, + ), + ( + self.examples[0][0], + [1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117], + {"bank": 14, "num_completed": 6, "finished": True, "is_root": True}, + ), + ( + tensorize([[1], [2, 3]]), + # Should not be able to get credit for entering 1 a second time + [1, 1], + {"bank": 1, "num_completed": 1, "finished": False, "is_root": True}, + ), + ( + self.examples[4][0], + [1, 2, 1, 2], + {"bank": 4, "num_completed":
2, "finished": True, "is_root": False}, + ), + ( + self.examples[4][0], + [1, 2, 1, 2, 1], + {"bank": 4, "num_completed": 2, "finished": True, "is_root": True}, + ), + ( + self.examples[5][0], + [1, 2, 3, 4, 5], + {"bank": 4, "num_completed": 2, "finished": True, "is_root": True}, + ), + ] + + def test_graphs(self): + """ + Test whether unordered graph systems are created correctly. + """ + for example in self.examples: + constraints, expected, gold_counts = example + c = ConstraintNode.create(constraints) + assert ( + ConstraintNode.print_graph(c) == expected + ), f"got {ConstraintNode.print_graph(c)}, expected {expected}" + assert ( + c.token_counts() == gold_counts + ), f"{c} got {c.token_counts()} wanted {gold_counts}" + + def test_next_tokens(self): + """ + Tests that the set of next tokens is correct. + """ + for example in self.examples: + constraints, expected, gold_counts = example + root = ConstraintNode.create(constraints) + + root_tokens = set(root.children.keys()) + for sequence in constraints: + state = UnorderedConstraintState(root) + for token in sequence: + all_tokens = root_tokens.union(state.node.children.keys()) + assert ( + all_tokens == state.next_tokens() + ), f"ALL {all_tokens} NEXT {state.next_tokens()}" + state = state.advance(token) + + def test_sequences(self): + for constraints, tokens, expected in self.sequences: + state = UnorderedConstraintState.create(pack_constraints([constraints])[0]) + for token in tokens: + state = state.advance(token) + result = {} + for attr in expected.keys(): + result[attr] = getattr(state, attr) + + assert ( + result == expected + ), f"TEST({tokens}) GOT: {result} WANTED: {expected}" + + +class TestOrderedConstraintState(unittest.TestCase): + def setUp(self): + self.sequences = [ + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [], + {"bank": 0, "num_completed": 0, "finished": False, "is_root": True}, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2], + {"bank": 2, "num_completed": 0, "finished": False, "is_root": False}, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 94], + {"bank": 0, "num_completed": 0, "finished": False, "is_root": True}, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 3, 999, 1, 4], + {"bank": 0, "num_completed": 0, "finished": False, "is_root": True}, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 3, 999, 999], + {"bank": 3, "num_completed": 1, "finished": False, "is_root": False}, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 3, 77, 1, 3, 1], + {"bank": 6, "num_completed": 2, "finished": False, "is_root": False}, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5], + {"bank": 14, "num_completed": 6, "finished": True, "is_root": False}, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 999, 1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117], + {"bank": 14, "num_completed": 6, "finished": True, "is_root": False}, + ), + ( + tensorize([[1], [2, 3]]), + [1, 1], + {"bank": 1, "num_completed": 1, "finished": False, "is_root": False}, + ), + ( + tensorize([[1, 2], [1, 2]]), + [1, 2, 1, 2], + {"bank": 4, "num_completed": 2, "finished": True, "is_root": False}, + ), + ( + tensorize([[1, 2], [1, 2]]), + [1, 2, 1, 2, 1], + {"bank": 4, "num_completed": 2, "finished": True, "is_root": False}, + ), 
+ ( + tensorize([[1, 2], [3, 4]]), + [1, 2, 3, 4, 5], + {"bank": 4, "num_completed": 2, "finished": True, "is_root": False}, + ), + ] + + def test_sequences(self): + for i, (constraints, tokens, expected) in enumerate(self.sequences): + state = OrderedConstraintState.create(pack_constraints([constraints])[0]) + for token in tokens: + state = state.advance(token) + result = {} + for attr in expected.keys(): + result[attr] = getattr(state, attr) + assert ( + result == expected + ), f"TEST({tokens}) GOT: {result} WANTED: {expected}" + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_convtbc.py b/fairseq/tests/test_convtbc.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3c9b91e70f597ab77b9b01459cc429db5d7956 --- /dev/null +++ b/fairseq/tests/test_convtbc.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +import torch.nn as nn +from fairseq.modules import ConvTBC + + +class TestConvTBC(unittest.TestCase): + def test_convtbc(self): + # ksz, in_channels, out_channels + conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1) + # out_channels, in_channels, ksz + conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1) + + conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2)) + conv_tbc.bias.data.copy_(conv1d.bias.data) + + input_tbc = torch.randn(7, 2, 4, requires_grad=True) + input1d = input_tbc.data.transpose(0, 1).transpose(1, 2) + input1d.requires_grad = True + + output_tbc = conv_tbc(input_tbc) + output1d = conv1d(input1d) + + self.assertAlmostEqual( + output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data + ) + + grad_tbc = torch.randn(output_tbc.size()) + grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous() + + output_tbc.backward(grad_tbc) + output1d.backward(grad1d) + + self.assertAlmostEqual( + conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data + ) + self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data) + self.assertAlmostEqual( + input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data + ) + + def assertAlmostEqual(self, t1, t2): + self.assertEqual(t1.size(), t2.size(), "size mismatch") + self.assertLess((t1 - t2).abs().max(), 1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_data_utils.py b/fairseq/tests/test_data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c48d02c5c6431b01a3a52aa4706d52e249399529 --- /dev/null +++ b/fairseq/tests/test_data_utils.py @@ -0,0 +1,136 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
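+ +# Both implementations under test share one calling convention (values are +# illustrative): +# +# indices = np.arange(10) +# num_tokens_vec = np.random.randint(0, 8, size=10) +# batches = batch_by_size_vec( +# indices, num_tokens_vec, max_tokens=8, max_sentences=4, bsz_mult=2 +# )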
+ +import unittest + +import numpy as np + +from fairseq.data.data_utils_fast import batch_by_size_fn, batch_by_size_vec + + +class TestBatchBySize(unittest.TestCase): + @classmethod + def batch_by_size_baseline( + cls, + indices, + num_tokens_vec, + max_tokens, + max_sentences, + bsz_mult, + ): + """Simple, reliable and slow implementation of batch by size""" + batches = [] + start = 0 + while start < len(indices): + for end in range(start + 1, len(indices) + 1): + max_val = max(num_tokens_vec[pos] for pos in range(start, end)) + sent_count = end - start + num_tokens = max_val * sent_count + overflow = num_tokens > max_tokens > 0 or sent_count > max_sentences > 0 + terminate = overflow or end == len(indices) + if overflow: + sent_count -= 1 + if terminate: + if sent_count > bsz_mult: + sent_count = sent_count - sent_count % bsz_mult + batches.append(indices[start : start + sent_count]) + start = start + sent_count + break + return batches + + @classmethod + def _get_error_message( + cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results + ): + return f"""Reference batch_by_size implementation should produce + same output as the baseline method. + Params: + max_sentences={max_sentences}, + max_tokens={max_tokens}, + bsz_mult={bsz_mult}, + num_tokens_vec={num_tokens_vec}, + expected_batches={validation}, + returned_batches={results}""" + + def _compare_results( + self, + indices_len, + batch_by_size_impl, + max_sentences, + max_tokens, + bsz_mult, + num_tokens_vec, + ): + indices = np.array(list(range(indices_len))) + validation = self.batch_by_size_baseline( + indices, + num_tokens_vec, + max_tokens=max_tokens, + max_sentences=max_sentences, + bsz_mult=bsz_mult, + ) + results = batch_by_size_impl( + indices, + num_tokens_vec, + max_tokens=max_tokens, + max_sentences=max_sentences, + bsz_mult=bsz_mult, + ) + error_msg = self._get_error_message( + max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results + ) + self.assertEqual(len(validation), len(results), error_msg) + for first, second in zip(validation, results): + self.assertTrue(np.array_equal(first, second), error_msg) + + def _run_compare_with_baseline_sweep(self, batch_by_size_impl): + """Compare reference batch_by_size implementation with batch_by_size_baseline + across a dense grid of hyperparam values""" + MAX_MAX_TOKENS = 10 + NUM_TOKENS_VECS_COUNT = 5 + for indices_len in [10, 11]: # try odd and even len of indices + for max_sentences in range(0, indices_len + 2): + for max_tokens in range(0, MAX_MAX_TOKENS): + for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2): + for _ in range(NUM_TOKENS_VECS_COUNT): + num_tokens_vec = np.random.randint( + 0, max_tokens + 1, size=indices_len + ) + self._compare_results( + indices_len, + batch_by_size_impl, + max_sentences, + max_tokens, + bsz_mult, + num_tokens_vec, + ) + + +class TestBatchBySizeVec(TestBatchBySize): + def test_compare_with_baseline(self): + self._run_compare_with_baseline_sweep(batch_by_size_vec) + + +class TestBatchBySizeFn(TestBatchBySize): + def test_compare_with_baseline(self): + def batch_by_size_fn_wrapper( + indices, + num_tokens_vec, + max_tokens, + max_sentences, + bsz_mult, + ): + def num_tokens_fn(idx): + return num_tokens_vec[idx] + + return batch_by_size_fn( + indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult + ) + + self._run_compare_with_baseline_sweep(batch_by_size_fn_wrapper) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_dataclass_utils.py 
b/fairseq/tests/test_dataclass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..231f86b6ee6f12fb70aff6e13cde6890eea76917 --- /dev/null +++ b/fairseq/tests/test_dataclass_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from argparse import ArgumentParser +from dataclasses import dataclass, field + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass + + +@dataclass +class A(FairseqDataclass): + data: str = field(default="test", metadata={"help": "the data input"}) + num_layers: int = field(default=200, metadata={"help": "more layers is better?"}) + + +@dataclass +class B(FairseqDataclass): + bar: A = field(default=A()) + foo: int = field(default=0, metadata={"help": "not a bar"}) + + +@dataclass +class D(FairseqDataclass): + arch: A = field(default=A()) + foo: int = field(default=0, metadata={"help": "not a bar"}) + + +@dataclass +class C(FairseqDataclass): + data: str = field(default="test", metadata={"help": "root level data input"}) + encoder: D = field(default=D()) + decoder: A = field(default=A()) + lr: int = field(default=0, metadata={"help": "learning rate"}) + + +class TestDataclassUtils(unittest.TestCase): + def test_argparse_convert_basic(self): + parser = ArgumentParser() + gen_parser_from_dataclass(parser, A(), True) + args = parser.parse_args(["--num-layers", "10", "the/data/path"]) + self.assertEqual(args.num_layers, 10) + self.assertEqual(args.data, "the/data/path") + + def test_argparse_recursive(self): + parser = ArgumentParser() + gen_parser_from_dataclass(parser, B(), True) + args = parser.parse_args(["--num-layers", "10", "--foo", "10", "the/data/path"]) + self.assertEqual(args.num_layers, 10) + self.assertEqual(args.foo, 10) + self.assertEqual(args.data, "the/data/path") + + def test_argparse_recursive_prefixing(self): + self.maxDiff = None + parser = ArgumentParser() + gen_parser_from_dataclass(parser, C(), True, "") + args = parser.parse_args( + [ + "--encoder-arch-data", + "ENCODER_ARCH_DATA", + "--encoder-arch-num-layers", + "10", + "--encoder-foo", + "10", + "--decoder-data", + "DECODER_DATA", + "--decoder-num-layers", + "10", + "--lr", + "10", + "the/data/path", + ] + ) + self.assertEqual(args.encoder_arch_data, "ENCODER_ARCH_DATA") + self.assertEqual(args.encoder_arch_num_layers, 10) + self.assertEqual(args.encoder_foo, 10) + self.assertEqual(args.decoder_data, "DECODER_DATA") + self.assertEqual(args.decoder_num_layers, 10) + self.assertEqual(args.lr, 10) + self.assertEqual(args.data, "the/data/path") + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_dataset.py b/fairseq/tests/test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e3970028bc4b0259153e403951e1735bb0cd3e --- /dev/null +++ b/fairseq/tests/test_dataset.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
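+ +# As the assertions below illustrate, after ordered_indices() item i of the +# zipped dataset pairs item i of the longest dataset with item +# (i % len(shorter)) of each shorter dataset: +# +# dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset}) +# dataset.ordered_indices() # sorts by sentence length +# dataset[2] # -> one sample per key, e.g. {"a": ..., "b": ...}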
+ +import logging +import unittest +from typing import Sequence + +from fairseq.data import LanguagePairDataset, ListDataset, RoundRobinZipDatasets +from tests.test_train import mock_dict + + +def lang_pair_dataset(lengths: Sequence[int]) -> LanguagePairDataset: + tokens = [[i] * l for i, l in enumerate(lengths)] + return LanguagePairDataset(ListDataset(tokens), lengths, mock_dict()) + + +def sample(id: int, length: int): + return {"id": id, "source": [id] * length, "target": None} + + +class TestDataset(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_round_robin_zip_datasets(self): + long_dataset = lang_pair_dataset([10, 9, 8, 11]) + short_dataset = lang_pair_dataset([11, 9]) + + dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset}) + # Dataset is now sorted by sentence length + dataset.ordered_indices() + assert dataset.longest_dataset is long_dataset + self.assertEqual(dict(dataset[0]), {"a": sample(2, 8), "b": sample(1, 9)}) + # Item 2 of dataset 'a' is paired with item (2 % 2 = 0) of dataset 'b' + self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(1, 9)}) + + def test_round_robin_zip_datasets_filtered(self): + long_dataset = lang_pair_dataset([10, 20, 8, 11, 1000, 7, 12]) + short_dataset = lang_pair_dataset([11, 20, 9, 1000]) + + dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset}) + # Dataset is now sorted by sentence length + idx = dataset.ordered_indices() + idx, _ = dataset.filter_indices_by_size(idx, {"a": 19, "b": 900}) + self.assertEqual(list(idx), [0, 1, 2, 3, 4]) + self.assertEqual(dict(dataset[0]), {"a": sample(5, 7), "b": sample(2, 9)}) + self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(1, 20)}) + self.assertEqual(dict(dataset[4]), {"a": sample(6, 12), "b": sample(0, 11)}) + + def test_round_robin_zip_datasets_filtered_with_tuple(self): + long_dataset = lang_pair_dataset([10, 20, 8, 11, 1000, 7, 12]) + short_dataset = lang_pair_dataset([11, 20, 9, 1000]) + + dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset}) + # Dataset is now sorted by sentence length + idx = dataset.ordered_indices() + idx, _ = dataset.filter_indices_by_size(idx, 19) + self.assertEqual(list(idx), [0, 1, 2, 3, 4]) + self.assertEqual(dict(dataset[0]), {"a": sample(5, 7), "b": sample(2, 9)}) + self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(2, 9)}) + self.assertEqual(dict(dataset[4]), {"a": sample(6, 12), "b": sample(2, 9)}) diff --git a/fairseq/tests/test_dictionary.py b/fairseq/tests/test_dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9d71b3c722ce3066e182d4b237b2a72999d4d0 --- /dev/null +++ b/fairseq/tests/test_dictionary.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
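+ +# The round-trip these tests depend on (a sketch; fairseq's Dictionary +# reserves the first indices for special symbols, so new symbols start at 4): +# +# d = Dictionary() +# d.encode_line("A B C D", add_if_not_exist=True) # build +# d.finalize() # re-sorts symbols by frequency +# ids = d.encode_line("A B C D", add_if_not_exist=False)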
+ +import io +import os +import string +import tempfile +import unittest + +import torch +from fairseq import tokenizer +from fairseq.data import Dictionary + + +class TestDictionary(unittest.TestCase): + def test_finalize(self): + txt = [ + "A B C D", + "B C D", + "C D", + "D", + ] + ref_ids1 = list( + map( + torch.IntTensor, + [ + [4, 5, 6, 7, 2], + [5, 6, 7, 2], + [6, 7, 2], + [7, 2], + ], + ) + ) + ref_ids2 = list( + map( + torch.IntTensor, + [ + [7, 6, 5, 4, 2], + [6, 5, 4, 2], + [5, 4, 2], + [4, 2], + ], + ) + ) + + # build dictionary + d = Dictionary() + for line in txt: + d.encode_line(line, add_if_not_exist=True) + + def get_ids(dictionary): + ids = [] + for line in txt: + ids.append(dictionary.encode_line(line, add_if_not_exist=False)) + return ids + + def assertMatch(ids, ref_ids): + for toks, ref_toks in zip(ids, ref_ids): + self.assertEqual(toks.size(), ref_toks.size()) + self.assertEqual(0, (toks != ref_toks).sum().item()) + + ids = get_ids(d) + assertMatch(ids, ref_ids1) + + # check finalized dictionary + d.finalize() + finalized_ids = get_ids(d) + assertMatch(finalized_ids, ref_ids2) + + # write to disk and reload + with tempfile.NamedTemporaryFile(mode="w") as tmp_dict: + d.save(tmp_dict.name) + d = Dictionary.load(tmp_dict.name) + reload_ids = get_ids(d) + assertMatch(reload_ids, ref_ids2) + assertMatch(finalized_ids, reload_ids) + + def test_overwrite(self): + # for example, Camembert overwrites <unk>, <s> and </s> + dict_file = io.StringIO( + "<unk> 999 #fairseq:overwrite\n" + "<s> 999 #fairseq:overwrite\n" + "</s> 999 #fairseq:overwrite\n" + ", 999\n" + "▁de 999\n" + ) + d = Dictionary() + d.add_from_file(dict_file) + self.assertEqual(d.index("<pad>"), 1) + self.assertEqual(d.index("foo"), 3) + self.assertEqual(d.index("<unk>"), 4) + self.assertEqual(d.index("<s>"), 5) + self.assertEqual(d.index("</s>"), 6) + self.assertEqual(d.index(","), 7) + self.assertEqual(d.index("▁de"), 8) + + def test_no_overwrite(self): + # for example, Camembert overwrites <unk>, <s> and </s> + dict_file = io.StringIO( + "<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n" + ) + d = Dictionary() + with self.assertRaisesRegex(RuntimeError, "Duplicate"): + d.add_from_file(dict_file) + + def test_space(self): + # for example, character models treat space as a symbol + dict_file = io.StringIO(" 999\n" "a 999\n" "b 999\n") + d = Dictionary() + d.add_from_file(dict_file) + self.assertEqual(d.index(" "), 4) + self.assertEqual(d.index("a"), 5) + self.assertEqual(d.index("b"), 6) + + def test_add_file_to_dict(self): + counts = {} + num_lines = 100 + per_line = 10 + with tempfile.TemporaryDirectory("test_sampling") as data_dir: + filename = os.path.join(data_dir, "dummy.txt") + with open(filename, "w", encoding="utf-8") as data: + for c in string.ascii_letters: + line = f"{c} " * per_line + for _ in range(num_lines): + data.write(f"{line}\n") + counts[c] = per_line * num_lines + per_line += 5 + + dict = Dictionary() + Dictionary.add_file_to_dictionary( + filename, dict, tokenizer.tokenize_line, 10 + ) + dict.finalize(threshold=0, nwords=-1, padding_factor=8) + + for c in string.ascii_letters: + count = dict.get_count(dict.index(c)) + self.assertEqual( + counts[c], count, f"{c} count is {count} but should be {counts[c]}" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_ema.py b/fairseq/tests/test_ema.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2cf2c78c6b8d791fc2a81e92c97af7682c052e --- /dev/null +++ b/fairseq/tests/test_ema.py @@ -0,0 +1,275 @@ +# Copyright (c) Facebook, Inc. and its affiliates.
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from copy import deepcopy +from dataclasses import dataclass +import pytest +from typing import Optional +from unittest.mock import patch + +import torch + +from fairseq.models.ema import EMA + + +class DummyModule(torch.nn.Module): + def __init__(self) -> None: + """A small two-layer module used as the EMA target in these tests; it takes no constructor arguments.""" + super().__init__() + self.layer = torch.nn.Linear(in_features=32, out_features=2) + self.another_layer = torch.nn.Linear(in_features=2, out_features=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer(x) + return self.another_layer(x) + + +@dataclass +class EMAConfig(object): + ema_decay: float = 0.99 + ema_start_update: int = 0 + ema_fp32: bool = False + ema_seed_model: Optional[str] = None + ema_update_freq: int = 1 + + +class TestEMA(unittest.TestCase): + def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None): + diff = x.float() - y.float() + diff_norm = torch.norm(diff) + other_norm = torch.norm(y.float()) + + if msg is None: + msg = "|input - other| > {} + {} * |other|".format(atol, rtol) + + self.assertLessEqual( + diff_norm, + atol + rtol * other_norm, + msg=msg, + ) + + def test_ema(self): + model = DummyModule() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig() + ema = EMA(model, config) + + # set decay + ema._set_decay(config.ema_decay) + self.assertEqual(ema.get_decay(), config.ema_decay) + + # get model + self.assertEqual(ema.get_model(), ema.model) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # EMA step + x = torch.randn(32) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + ema_state_dict = ema.get_model().state_dict() + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema_state_dict[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # Load EMA into model + model2 = DummyModule() + ema.reverse(model2) + + for key, param in model2.state_dict().items(): + ema_param = ema_state_dict[key] + self.assertTrue(torch.allclose(ema_param, param)) + + # Check that step_internal is called once + with patch.object(ema, "_step_internal", return_value=None) as mock_method: + ema.step(model) + mock_method.assert_called_once_with(model, None) + + def _test_ema_start_update(self, updates): + model = DummyModule() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_start_update=1) + ema = EMA(model, config) + + # EMA step + x = torch.randn(32) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model, updates=updates) + ema_state_dict = ema.get_model().state_dict() + + self.assertEqual(ema.get_decay(), 0 if updates == 0 else config.ema_decay) + + for key, param in model.state_dict().items(): + ema_param = ema_state_dict[key] + prev_param = state[key]
+ + if "version" in key: + # Do not decay a model.version pytorch param + continue + if updates == 0: + self.assertTorchAllClose( + ema_param, + param, + ) + else: + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Check that step_internal is called once + with patch.object(ema, "_step_internal", return_value=None) as mock_method: + ema.step(model, updates=updates) + mock_method.assert_called_once_with(model, updates) + + def test_ema_before_start_update(self): + self._test_ema_start_update(updates=0) + + def test_ema_after_start_update(self): + self._test_ema_start_update(updates=1) + + def test_ema_fp32(self): + dtype = torch.float + + model = DummyModule().to(dtype) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=True) + ema = EMA(model, config) + + x = torch.randn(32) + y = model(x.to(dtype)) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertIn(key, ema.fp32_params) + + # EMA update is done in fp32, and hence the EMA param must be + # closer to the EMA update done in fp32 than in fp16. + self.assertLessEqual( + torch.norm( + ema_param.float() + - ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ) + .to(dtype) + .float() + ), + torch.norm( + ema_param.float() + - ( + config.ema_decay * prev_param + (1 - config.ema_decay) * param + ).float() + ), + ) + self.assertTorchAllClose( + ema_param, + ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ).to(dtype), + ) + + @pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CPU no longer supports Linear in half precision", + ) + def test_ema_fp16(self): + model = DummyModule().cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=False) + ema = EMA(model, config) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + x = torch.randn(32).cuda() + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + + # EMA update is done in fp16, and hence the EMA param must be + # closer to the EMA update done in fp16 than in fp32. 
+ self.assertLessEqual( + torch.norm( + ema_param.float() + - ( + config.ema_decay * prev_param + (1 - config.ema_decay) * param + ).float() + ), + torch.norm( + ema_param.float() + - ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ) + .half() + .float() + ), + ) + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_espnet_multihead_attention.py b/fairseq/tests/test_espnet_multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ee71dd0e984520daba0dcd2511de0da338e4924c --- /dev/null +++ b/fairseq/tests/test_espnet_multihead_attention.py @@ -0,0 +1,176 @@ +import torch +import numpy as np +import unittest +from fairseq.modules import ( + ESPNETMultiHeadedAttention, + RelPositionMultiHeadedAttention, + RotaryPositionMultiHeadedAttention, +) + +torch.use_deterministic_algorithms(True) + + +class TestESPNETMultiHeadedAttention(unittest.TestCase): + def setUp(self) -> None: + self.T = 3 + self.B = 1 + self.C = 2 + torch.manual_seed(0) + self.sample = torch.randn(self.T, self.B, self.C) # TBC + self.sample_scores = torch.randn(self.B, 1, self.T, self.T) + self.MHA = ESPNETMultiHeadedAttention(self.C, 1, dropout=0) + + def test_forward(self): + expected_scores = torch.tensor( + [[[0.1713, -0.3776]], [[0.2263, -0.4486]], [[0.2243, -0.4538]]] + ) + scores, _ = self.MHA(self.sample, self.sample, self.sample) + self.assertTrue( + np.allclose( + expected_scores.cpu().detach().numpy(), + scores.cpu().detach().numpy(), + atol=1e-4, + ) + ) + + def test_forward_qkv(self): + expected_query = torch.tensor( + [[[[-1.0235, 0.0409], [0.4008, 1.3077], [0.5396, 2.0698]]]] + ) + expected_key = torch.tensor( + [[[[0.5053, -0.4965], [-0.3730, -0.9473], [-0.7019, -0.1935]]]] + ) + expected_val = torch.tensor( + [[[[-0.9940, 0.5403], [0.5924, -0.7619], [0.7504, -1.0892]]]] + ) + sample_t = self.sample.transpose(0, 1) + query, key, val = self.MHA.forward_qkv(sample_t, sample_t, sample_t) + self.assertTrue( + np.allclose( + expected_query.cpu().detach().numpy(), + query.cpu().detach().numpy(), + atol=1e-4, + ) + ) + self.assertTrue( + np.allclose( + expected_key.cpu().detach().numpy(), + key.cpu().detach().numpy(), + atol=1e-4, + ) + ) + self.assertTrue( + np.allclose( + expected_val.cpu().detach().numpy(), + val.cpu().detach().numpy(), + atol=1e-4, + ) + ) + + def test_forward_attention(self): + expected_scores = torch.tensor( + [[[0.1627, -0.6249], [-0.2547, -0.6487], [-0.0711, -0.8545]]] + ) + scores = self.MHA.forward_attention( + self.sample.transpose(0, 1).view(self.B, 1, self.T, self.C), + self.sample_scores, + mask=None, + ) + self.assertTrue( + np.allclose( + expected_scores.cpu().detach().numpy(), + scores.cpu().detach().numpy(), + atol=1e-4, + ) + ) + + +class TestRelPositionMultiHeadedAttention(unittest.TestCase): + def setUp(self) -> None: + self.T = 3 + self.B = 1 + self.C = 2 + torch.manual_seed(0) + self.sample = torch.randn(self.T, self.B, self.C) # TBC + self.sample_x = torch.randn(self.B, 1, self.T, self.T * 2 - 1) + self.sample_pos = torch.randn(self.B, self.T * 2 - 1, self.C) + self.MHA = RelPositionMultiHeadedAttention(self.C, 1, dropout=0) + + def test_rel_shift(self): + expected_x = torch.tensor( + [ + [ + [ + [-0.7193, -0.4033, -0.5966], + [-0.8567, 1.1006, -1.0712], + [-0.5663, 
0.3731, -0.8920], + ] + ] + ] + ) + x = self.MHA.rel_shift(self.sample_x) + self.assertTrue( + np.allclose( + expected_x.cpu().detach().numpy(), + x.cpu().detach().numpy(), + atol=1e-4, + ) + ) + + def test_forward(self): + expected_scores = torch.tensor( + [ + [[-0.9609, -0.5020]], + [[-0.9308, -0.4890]], + [[-0.9473, -0.4948]], + [[-0.9609, -0.5020]], + [[-0.9308, -0.4890]], + [[-0.9473, -0.4948]], + [[-0.9609, -0.5020]], + [[-0.9308, -0.4890]], + [[-0.9473, -0.4948]], + [[-0.9609, -0.5020]], + [[-0.9308, -0.4890]], + [[-0.9473, -0.4948]], + [[-0.9609, -0.5020]], + [[-0.9308, -0.4890]], + [[-0.9473, -0.4948]], + ] + ) + scores, _ = self.MHA(self.sample, self.sample, self.sample, self.sample_pos) + self.assertTrue( + np.allclose( + expected_scores.cpu().detach().numpy(), + scores.cpu().detach().numpy(), + atol=1e-4, + ) + ) + + +class TestRotaryPositionMultiHeadedAttention(unittest.TestCase): + def setUp(self) -> None: + self.T = 3 + self.B = 1 + self.C = 2 + torch.manual_seed(0) + self.sample = torch.randn(self.T, self.B, self.C) # TBC + self.MHA = RotaryPositionMultiHeadedAttention( + self.C, 1, dropout=0, precision=None + ) + + def test_forward(self): + expected_scores = torch.tensor( + [[[-0.3220, -0.4726]], [[-1.2813, -0.0979]], [[-0.3138, -0.4758]]] + ) + scores, _ = self.MHA(self.sample, self.sample, self.sample) + self.assertTrue( + np.allclose( + expected_scores.cpu().detach().numpy(), + scores.cpu().detach().numpy(), + atol=1e-4, + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_export.py b/fairseq/tests/test_export.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9a48d187930fc2191d184190ad86fb4951e8bb --- /dev/null +++ b/fairseq/tests/test_export.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import tempfile +import unittest + +import torch + +from fairseq.data.dictionary import Dictionary +from fairseq.models.transformer import TransformerModel +from fairseq.modules import multihead_attention, sinusoidal_positional_embedding +from fairseq.tasks.fairseq_task import LegacyFairseqTask + +DEFAULT_TEST_VOCAB_SIZE = 100 + + +class DummyTask(LegacyFairseqTask): + def __init__(self, args): + super().__init__(args) + self.dictionary = get_dummy_dictionary() + if getattr(self.args, "ctc", False): + self.dictionary.add_symbol("<ctc_blank>") + self.src_dict = self.dictionary + self.tgt_dict = self.dictionary + + @property + def source_dictionary(self): + return self.src_dict + + @property + def target_dictionary(self): + return self.dictionary + + +def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE): + dummy_dict = Dictionary() + # add dummy symbol to satisfy vocab size + for id, _ in enumerate(range(vocab_size)): + dummy_dict.add_symbol("{}".format(id), 1000) + return dummy_dict + + +def get_dummy_task_and_parser(): + """ + Return a dummy task and argument parser, which can be used to + create a model/criterion.
+ """ + parser = argparse.ArgumentParser( + description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS + ) + DummyTask.add_args(parser) + args = parser.parse_args([]) + task = DummyTask.setup_task(args) + return task, parser + + +def _test_save_and_load(scripted_module): + with tempfile.NamedTemporaryFile() as f: + scripted_module.save(f.name) + torch.jit.load(f.name) + + +class TestExportModels(unittest.TestCase): + def test_export_multihead_attention(self): + module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2) + scripted = torch.jit.script(module) + _test_save_and_load(scripted) + + def test_incremental_state_multihead_attention(self): + module1 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2) + module1 = torch.jit.script(module1) + module2 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2) + module2 = torch.jit.script(module2) + + state = {} + state = module1.set_incremental_state(state, "key", {"a": torch.tensor([1])}) + state = module2.set_incremental_state(state, "key", {"a": torch.tensor([2])}) + v1 = module1.get_incremental_state(state, "key")["a"] + v2 = module2.get_incremental_state(state, "key")["a"] + + self.assertEqual(v1, 1) + self.assertEqual(v2, 2) + + def test_positional_embedding(self): + module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding( + embedding_dim=8, padding_idx=1 + ) + scripted = torch.jit.script(module) + _test_save_and_load(scripted) + + @unittest.skipIf( + torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" + ) + def test_export_transformer(self): + task, parser = get_dummy_task_and_parser() + TransformerModel.add_args(parser) + args = parser.parse_args([]) + model = TransformerModel.build_model(args, task) + scripted = torch.jit.script(model) + _test_save_and_load(scripted) + + @unittest.skipIf( + torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" + ) + def test_export_transformer_no_token_pos_emb(self): + task, parser = get_dummy_task_and_parser() + TransformerModel.add_args(parser) + args = parser.parse_args([]) + args.no_token_positional_embeddings = True + model = TransformerModel.build_model(args, task) + scripted = torch.jit.script(model) + _test_save_and_load(scripted) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/test_file_chunker_utils.py b/fairseq/tests/test_file_chunker_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5cded04572f0ab68c81db9ad14de1c18951a1a10 --- /dev/null +++ b/fairseq/tests/test_file_chunker_utils.py @@ -0,0 +1,63 @@ +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +import tempfile +import unittest +from typing import Optional + + +class TestFileChunker(unittest.TestCase): + _tmpdir: Optional[str] = None + _tmpfile: Optional[str] = None + _line_content = "Hello, World\n" + _num_bytes = None + _num_lines = 200 + _num_splits = 20 + + @classmethod + def setUpClass(cls) -> None: + cls._num_bytes = len(cls._line_content.encode("utf-8")) + cls._tmpdir = tempfile.mkdtemp() + with open(os.path.join(cls._tmpdir, "test.txt"), "w") as f: + cls._tmpfile = f.name + for _i in range(cls._num_lines): + f.write(cls._line_content) + f.flush() + + @classmethod + def tearDownClass(cls) -> None: + # Cleanup temp working dir. 
+ if cls._tmpdir is not None: + shutil.rmtree(cls._tmpdir) # type: ignore + + def test_find_offsets(self): + from fairseq.file_chunker_utils import find_offsets + + offsets = find_offsets(self._tmpfile, self._num_splits) + self.assertEqual(len(offsets), self._num_splits + 1) + (zero, *real_offsets, last) = offsets + self.assertEqual(zero, 0) + for i, o in enumerate(real_offsets): + self.assertEqual( + o, + self._num_bytes + + ((i + 1) * self._num_bytes * self._num_lines / self._num_splits), + ) + self.assertEqual(last, self._num_bytes * self._num_lines) + + def test_readchunks(self): + from fairseq.file_chunker_utils import Chunker, find_offsets + + offsets = find_offsets(self._tmpfile, self._num_splits) + for start, end in zip(offsets, offsets[1:]): + with Chunker(self._tmpfile, start, end) as lines: + all_lines = list(lines) + num_lines = self._num_lines / self._num_splits + self.assertAlmostEqual( + len(all_lines), num_lines, delta=1 + ) # because we split on the bytes, we might end up with one more/fewer line in a chunk + self.assertListEqual( + all_lines, [self._line_content for _ in range(len(all_lines))] + ) diff --git a/fairseq/tests/test_file_io.py b/fairseq/tests/test_file_io.py new file mode 100644 index 0000000000000000000000000000000000000000..af7c4cedb8b7f24b69905d768c9ba5cd75d48cfd --- /dev/null +++ b/fairseq/tests/test_file_io.py @@ -0,0 +1,59 @@ +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +import sys +import tempfile +import unittest +from typing import Optional +from unittest.mock import MagicMock + + +class TestFileIO(unittest.TestCase): + + _tmpdir: Optional[str] = None + _tmpfile: Optional[str] = None + _tmpfile_contents = "Hello, World" + + @classmethod + def setUpClass(cls) -> None: + cls._tmpdir = tempfile.mkdtemp() + with open(os.path.join(cls._tmpdir, "test.txt"), "w") as f: + cls._tmpfile = f.name + f.write(cls._tmpfile_contents) + f.flush() + + @classmethod + def tearDownClass(cls) -> None: + # Cleanup temp working dir. + if cls._tmpdir is not None: + shutil.rmtree(cls._tmpdir) # type: ignore + + def test_file_io(self): + from fairseq.file_io import PathManager + + with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: + s = f.read() + self.assertEqual(s, self._tmpfile_contents) + + def test_file_io_oss(self): + # Mock iopath to simulate oss environment. + sys.modules["iopath"] = MagicMock() + from fairseq.file_io import PathManager + + with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: + s = f.read() + self.assertEqual(s, self._tmpfile_contents) + + def test_file_io_async(self): + # ioPath `PathManager` is initialized after the first `opena` call. + try: + from fairseq.file_io import PathManager + + _asyncfile = os.path.join(self._tmpdir, "async.txt") + f = PathManager.opena(_asyncfile, "wb") + f.close() + + finally: + self.assertTrue(PathManager.async_close()) diff --git a/fairseq/tests/test_fp16_optimizer.py b/fairseq/tests/test_fp16_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..27085a12da615e5fa287b45987562410ca4d80b2 --- /dev/null +++ b/fairseq/tests/test_fp16_optimizer.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
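+ +# A sketch of the training step exercised below (model, loss, and cfg come +# from setUp; the comments describe our reading of the calls, not the +# optimizer's documentation): +# +# optimizer = FP16Optimizer.build_optimizer(cfg, list(model.parameters())) +# optimizer.backward(loss) # backward on the scaled loss +# grad_norm = optimizer.clip_grad_norm(0) # max_norm=0: no clipping, just the norm +# optimizer.step()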
+
+import copy
+import logging
+import unittest
+
+import torch
+from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
+from omegaconf import OmegaConf
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+class TestGradientScaling(unittest.TestCase):
+    def setUp(self):
+        self.x = torch.tensor([2.0]).cuda().half()
+        weight = 3.0
+        bias = 5.0
+        self.error = 1.0
+        self.target = torch.tensor([self.x * weight + bias + self.error]).cuda().half()
+        self.loss_fn = torch.nn.L1Loss()
+
+        self.model = torch.nn.Linear(1, 1)
+        self.model.weight.data = torch.tensor([[weight]])
+        self.model.bias.data = torch.tensor([bias])
+        self.model.cuda().half()
+        self.params = list(self.model.parameters())
+
+        self.cfg_dls = OmegaConf.create(
+            {
+                "optimization": {
+                    "lr": [0.1],
+                },
+                "optimizer": {
+                    "_name": "adam",
+                    "lr": [0.1],
+                    "adam_betas": "(0.9, 0.999)",
+                    "adam_eps": 1e-8,
+                    "weight_decay": 0.0,
+                },
+                "common": {
+                    "fp16_init_scale": 1,
+                    "fp16_scale_window": 1,
+                    "fp16_scale_tolerance": 1,
+                    "threshold_loss_scale": 1,
+                    "min_loss_scale": 1e-4,
+                    "tpu": False,
+                },
+            }
+        )
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def run_iter(self, model, params, optimizer):
+        optimizer.zero_grad()
+        y = model(self.x)
+        loss = self.loss_fn(y, self.target)
+        optimizer.backward(loss)
+        self.assertEqual(loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16))
+
+        grad_norm = optimizer.clip_grad_norm(0)
+        self.assertAlmostEqual(grad_norm.item(), 2.2361, 4)
+
+        optimizer.step()
+        self.assertEqual(
+            model.weight,
+            torch.tensor(
+                [[3.0996]], device="cuda:0", dtype=torch.float16, requires_grad=True
+            ),
+        )
+        self.assertEqual(
+            model.bias,
+            torch.tensor(
+                [5.1016], device="cuda:0", dtype=torch.float16, requires_grad=True
+            ),
+        )
+        self.assertEqual(optimizer.scaler.loss_scale, 2.0)
+
+    def test_mixed_precision(self):
+        model = copy.deepcopy(self.model)
+        params = list(model.parameters())
+        optimizer = FP16Optimizer.build_optimizer(self.cfg_dls, params)
+
+        self.run_iter(model, params, optimizer)
+        self.assertTrue(
+            all(
+                torch.all(
+                    fp32_params.eq(
+                        torch.tensor(
+                            [3.1000, 5.1000], device="cuda:0", requires_grad=True
+                        )
+                    )
+                )
+                for fp32_params in optimizer.fp32_params.values()
+            )
+        )
+
+    def test_memory_efficient(self):
+        model = copy.deepcopy(self.model)
+        params = list(model.parameters())
+        optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.cfg_dls, params)
+
+        self.run_iter(model, params, optimizer)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fairseq/tests/test_hf_hub.py b/fairseq/tests/test_hf_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cfef70d061049602cc47292ede2a0d7aa105608
--- /dev/null
+++ b/fairseq/tests/test_hf_hub.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
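+#
+# Usage sketch for the loader exercised below (the model id is the one the
+# test uses; downloading it requires network access):
+#
+#     from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
+#     models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
+#         "facebook/fastspeech2-en-ljspeech"
+#     )
+#     model = models[0].eval()  # each ensemble member is an ordinary fairseq model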
+
+import unittest
+
+import torch
+
+try:
+    import huggingface_hub
+except ImportError:
+    huggingface_hub = None
+
+from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
+
+
+@unittest.skipIf(not huggingface_hub, "Requires huggingface_hub install")
+class TestHuggingFaceHub(unittest.TestCase):
+    @torch.no_grad()
+    def test_hf_fastspeech2(self):
+        hf_model_id = "facebook/fastspeech2-en-ljspeech"
+        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(hf_model_id)
+        self.assertTrue(len(models) > 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fairseq/tests/test_huffman.py b/fairseq/tests/test_huffman.py
new file mode 100644
index 0000000000000000000000000000000000000000..85d0c72a7668a752caa57f107566b2f66480f8c3
--- /dev/null
+++ b/fairseq/tests/test_huffman.py
@@ -0,0 +1,179 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import typing as tp
+import unittest
+from collections import Counter
+from tempfile import NamedTemporaryFile, TemporaryDirectory
+
+from fairseq.data import Dictionary, indexed_dataset
+from fairseq.data.huffman import (
+    HuffmanCodeBuilder,
+    HuffmanCoder,
+    HuffmanMMapIndexedDataset,
+    HuffmanMMapIndexedDatasetBuilder,
+)
+from tests.utils import POPULATION, make_data, sizes
+
+
+def make_counts(data: tp.List[tp.List[str]]) -> Counter:
+    return Counter([symbol for sentence in data for symbol in sentence])
+
+
+def make_code_builder(data: tp.List[tp.List[str]]) -> HuffmanCodeBuilder:
+    builder = HuffmanCodeBuilder()
+    for sentence in data:
+        builder.add_symbols(*sentence)
+    return builder
+
+
+class TestCodeBuilder(unittest.TestCase):
+    def test_code_builder_can_count(self):
+        data = make_data()
+        counts = make_counts(data)
+        builder = make_code_builder(data)
+
+        self.assertEqual(builder.symbols, counts)
+
+    def test_code_builder_can_add(self):
+        data = make_data()
+        counts = make_counts(data)
+        builder = make_code_builder(data)
+
+        new_builder = builder + builder
+
+        self.assertEqual(new_builder.symbols, counts + counts)
+
+    def test_code_builder_can_io(self):
+        data = make_data()
+        builder = make_code_builder(data)
+
+        with NamedTemporaryFile() as tmp_fp:
+            builder.to_file(tmp_fp.name)
+            other_builder = HuffmanCodeBuilder.from_file(tmp_fp.name)
+
+            self.assertEqual(builder.symbols, other_builder.symbols)
+
+
+class TestCoder(unittest.TestCase):
+    def test_coder_can_io(self):
+        data = make_data()
+        builder = make_code_builder(data)
+        coder = builder.build_code()
+
+        with NamedTemporaryFile() as tmp_fp:
+            coder.to_file(tmp_fp.name)
+            other_coder = HuffmanCoder.from_file(tmp_fp.name)
+
+            self.assertEqual(coder, other_coder)
+
+    def test_coder_can_encode_decode(self):
+        data = make_data()
+        builder = make_code_builder(data)
+        coder = builder.build_code()
+
+        encoded = [coder.encode(sentence) for sentence in data]
+        decoded = [[n.symbol for n in coder.decode(enc)] for enc in encoded]
+
+        self.assertEqual(decoded, data)
+
+        unseen_data = make_data()
+        unseen_encoded = [coder.encode(sentence) for sentence in unseen_data]
+        unseen_decoded = [
+            [n.symbol for n in coder.decode(enc)] for enc in unseen_encoded
+        ]
+        self.assertEqual(unseen_decoded, unseen_data)
+
+
+def build_dataset(prefix, data, coder):
+    with HuffmanMMapIndexedDatasetBuilder(prefix, coder) as builder:
+        for sentence in data:
+            builder.add_item(sentence)
+
+
+class TestHuffmanDataset(unittest.TestCase):
+    def test_huffman_can_encode_decode(self):
+        data = make_data()
+        builder = make_code_builder(data)
+        coder = builder.build_code()
+
+        with TemporaryDirectory() as dirname:
+            prefix = os.path.join(dirname, "test1")
+            build_dataset(prefix, data, coder)
+            dataset = HuffmanMMapIndexedDataset(prefix)
+
+            self.assertEqual(len(dataset), len(data))
+            decoded = [list(dataset.get_symbols(i)) for i in range(0, len(dataset))]
+
+            self.assertEqual(decoded, data)
+            data_sizes = [i.item() for i in dataset.sizes]
+            self.assertEqual(data_sizes, sizes(data))
+
+    def test_huffman_compresses(self):
+        data = make_data()
+        builder = make_code_builder(data)
+        coder = builder.build_code()
+
+        with TemporaryDirectory() as dirname:
+            prefix = os.path.join(dirname, "huffman")
+            build_dataset(prefix, data, coder)
+
+            prefix_mmap = os.path.join(dirname, "mmap")
+            mmap_builder = indexed_dataset.make_builder(
+                indexed_dataset.data_file_path(prefix_mmap),
+                "mmap",
+                vocab_size=len(POPULATION),
+            )
+            dictionary = Dictionary()
+            for c in POPULATION:
+                dictionary.add_symbol(c)
+            dictionary.finalize()
+            for sentence in data:
+                mmap_builder.add_item(dictionary.encode_line(" ".join(sentence)))
+            mmap_builder.finalize(indexed_dataset.index_file_path(prefix_mmap))
+
+            huff_size = os.stat(indexed_dataset.data_file_path(prefix)).st_size
+            mmap_size = os.stat(indexed_dataset.data_file_path(prefix_mmap)).st_size
+            self.assertLess(huff_size, mmap_size)
+
+    def test_huffman_can_append(self):
+        data1 = make_data()
+        builder = make_code_builder(data1)
+        coder = builder.build_code()
+
+        with TemporaryDirectory() as dirname:
+            prefix1 = os.path.join(dirname, "test1")
+            build_dataset(prefix1, data1, coder)
+
+            data2 = make_data()
+            prefix2 = os.path.join(dirname, "test2")
+            build_dataset(prefix2, data2, coder)
+
+            prefix3 = os.path.join(dirname, "test3")
+
+            with HuffmanMMapIndexedDatasetBuilder(prefix3, coder) as builder:
+                builder.append(prefix1)
+                builder.append(prefix2)
+
+            dataset = HuffmanMMapIndexedDataset(prefix3)
+
+            self.assertEqual(len(dataset), len(data1) + len(data2))
+
+            decoded1 = [list(dataset.get_symbols(i)) for i in range(0, len(data1))]
+            self.assertEqual(decoded1, data1)
+
+            decoded2 = [
+                list(dataset.get_symbols(i)) for i in range(len(data1), len(dataset))
+            ]
+            self.assertEqual(decoded2, data2)
+
+            data_sizes = [i.item() for i in dataset.sizes]
+            self.assertEqual(data_sizes[: len(data1)], sizes(data1))
+            self.assertEqual(data_sizes[len(data1) : len(dataset)], sizes(data2))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fairseq/tests/test_inference_dropout.py b/fairseq/tests/test_inference_dropout.py
new file mode 100644
index 0000000000000000000000000000000000000000..353ac674780a9795492c75aa0a7bc0677b07a9c9
--- /dev/null
+++ b/fairseq/tests/test_inference_dropout.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
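+#
+# These tests cover the --retain-dropout flag: with it set,
+# prepare_for_inference_() leaves the selected dropout modules active at
+# inference time (e.g. for Monte Carlo dropout style uncertainty estimates).
+# Minimal sketch mirroring the assertions below:
+#
+#     args.retain_dropout = True
+#     model = TransformerModel.build_model(args, task)
+#     model.prepare_for_inference_(convert_namespace_to_omegaconf(args))
+#     assert model.encoder.dropout_module.apply_during_inference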
+
+import logging
+import unittest
+
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.models.transformer import TransformerModel
+from tests.test_sequence_generator import get_dummy_task_and_parser
+
+
+class TestInferenceDropout(unittest.TestCase):
+    def setUp(self):
+        self.task, self.parser = get_dummy_task_and_parser()
+        TransformerModel.add_args(self.parser)
+        self.args = self.parser.parse_args([])
+        self.args.encoder_layers = 2
+        self.args.decoder_layers = 1
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def test_sets_inference_dropout_to_true(self):
+        self.args.retain_dropout = True
+        self.transformer_model = TransformerModel.build_model(self.args, self.task)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
+        assert self.transformer_model.encoder.dropout_module.apply_during_inference
+        assert self.transformer_model.decoder.dropout_module.apply_during_inference
+        for layer in self.transformer_model.encoder.layers:
+            assert layer.dropout_module.apply_during_inference
+
+    def test_inference_dropout_false_by_default(self):
+        self.transformer_model = TransformerModel.build_model(self.args, self.task)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
+        assert not self.transformer_model.encoder.dropout_module.apply_during_inference
+        assert not self.transformer_model.decoder.dropout_module.apply_during_inference
+        for layer in self.transformer_model.encoder.layers:
+            assert not layer.dropout_module.apply_during_inference
+        for layer in self.transformer_model.decoder.layers:
+            assert not layer.dropout_module.apply_during_inference
+
+    def test_applies_training_mode(self):
+        self.transformer_model = TransformerModel.build_model(self.args, self.task)
+        assert self.transformer_model.encoder.dropout_module.training
+        for layer in self.transformer_model.encoder.layers:
+            assert layer.dropout_module.training
+
+        self.transformer_model.eval()
+        assert not self.transformer_model.decoder.dropout_module.training
+        for layer in self.transformer_model.encoder.layers:
+            assert not layer.dropout_module.training
+
+    def test_retain_modules(self):
+        self.args.retain_dropout = True
+        self.args.retain_dropout_modules = [
+            "TransformerEncoder",
+            "TransformerEncoderLayer",
+        ]
+        self.transformer_model = TransformerModel.build_model(self.args, self.task)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
+        assert self.transformer_model.encoder.dropout_module.apply_during_inference
+        assert not self.transformer_model.decoder.dropout_module.apply_during_inference
+        for layer in self.transformer_model.decoder.layers:
+            assert not layer.dropout_module.apply_during_inference
diff --git a/fairseq/tests/test_iopath.py b/fairseq/tests/test_iopath.py
new file mode 100644
index 0000000000000000000000000000000000000000..48230a6379a2253b4f8dd6210762c4cb8221b91d
--- /dev/null
+++ b/fairseq/tests/test_iopath.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
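+#
+# The trick used below: setting a key to None in sys.modules makes any
+# subsequent import of that module raise ImportError, so the tests can
+# simulate an environment without iopath installed. E.g.:
+#
+#     with mock.patch.dict("sys.modules", {"iopath": None}):
+#         import iopath  # raises ImportError inside this block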
+
+import unittest
+from unittest import mock
+
+
+class TestIOPath(unittest.TestCase):
+    def test_no_iopath(self):
+        from .test_reproducibility import TestReproducibility
+
+        with mock.patch.dict("sys.modules", {"iopath": None}):
+            # reuse reproducibility tests, which are e2e tests that should cover
+            # most checkpoint related functionality
+            TestReproducibility._test_reproducibility(self, "test_reproducibility")
+
+    def test_no_supports_rename(self):
+        from .test_reproducibility import TestReproducibility
+
+        with mock.patch("fairseq.file_io.PathManager.supports_rename") as mock_fn:
+            mock_fn.return_value = False
+            TestReproducibility._test_reproducibility(self, "test_reproducibility")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fairseq/tests/test_iterators.py b/fairseq/tests/test_iterators.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e2eb2f0a8694c199f1d89344e0b9044ca649016
--- /dev/null
+++ b/fairseq/tests/test_iterators.py
@@ -0,0 +1,194 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from fairseq.data import iterators, ListDataset
+
+
+class TestIterators(unittest.TestCase):
+    def test_counting_iterator_index(self, ref=None, itr=None):
+        # Test the indexing functionality of CountingIterator
+        if ref is None:
+            assert itr is None
+            ref = list(range(10))
+            itr = iterators.CountingIterator(ref)
+        else:
+            assert len(ref) == 10
+            assert itr is not None
+
+        self.assertTrue(itr.has_next())
+        self.assertEqual(itr.n, 0)
+        self.assertEqual(next(itr), ref[0])
+        self.assertEqual(itr.n, 1)
+        self.assertEqual(next(itr), ref[1])
+        self.assertEqual(itr.n, 2)
+        itr.skip(3)
+        self.assertEqual(itr.n, 5)
+        self.assertEqual(next(itr), ref[5])
+        itr.skip(2)
+        self.assertEqual(itr.n, 8)
+        self.assertEqual(list(itr), [ref[8], ref[9]])
+        self.assertFalse(itr.has_next())
+
+    def test_counting_iterator_length_mismatch(self):
+        ref = list(range(10))
+        # When the underlying iterable is longer than the CountingIterator,
+        # the remaining items in the iterable should be ignored
+        itr = iterators.CountingIterator(ref, total=8)
+        self.assertEqual(list(itr), ref[:8])
+        # When the underlying iterable is shorter than the CountingIterator,
+        # raise an IndexError when the underlying iterable is exhausted
+        itr = iterators.CountingIterator(ref, total=12)
+        self.assertRaises(IndexError, list, itr)
+
+    def test_counting_iterator_take(self):
+        # Test the "take" method of CountingIterator
+        ref = list(range(10))
+        itr = iterators.CountingIterator(ref)
+        itr.take(5)
+        self.assertEqual(len(itr), len(list(iter(itr))))
+        self.assertEqual(len(itr), 5)
+
+        itr = iterators.CountingIterator(ref)
+        itr.take(5)
+        self.assertEqual(next(itr), ref[0])
+        self.assertEqual(next(itr), ref[1])
+        itr.skip(2)
+        self.assertEqual(next(itr), ref[4])
+        self.assertFalse(itr.has_next())
+
+    def test_grouped_iterator(self):
+        # test correctness
+        x = list(range(10))
+        itr = iterators.GroupedIterator(x, 1)
+        self.assertEqual(list(itr), [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]])
+        itr = iterators.GroupedIterator(x, 4)
+        self.assertEqual(list(itr), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]])
+        itr = iterators.GroupedIterator(x, 5)
+        self.assertEqual(list(itr), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
+
+        # test that GroupedIterator also works correctly as a CountingIterator
+        x = list(range(30))
+        ref = list(iterators.GroupedIterator(x, 3))
+        itr = iterators.GroupedIterator(x, 3)
+        self.test_counting_iterator_index(ref, itr)
+
+    def test_sharded_iterator(self):
+        # test correctness
+        x = list(range(10))
+        itr = iterators.ShardedIterator(x, num_shards=1, shard_id=0)
+        self.assertEqual(list(itr), x)
+        itr = iterators.ShardedIterator(x, num_shards=2, shard_id=0)
+        self.assertEqual(list(itr), [0, 2, 4, 6, 8])
+        itr = iterators.ShardedIterator(x, num_shards=2, shard_id=1)
+        self.assertEqual(list(itr), [1, 3, 5, 7, 9])
+        itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0)
+        self.assertEqual(list(itr), [0, 3, 6, 9])
+        itr = iterators.ShardedIterator(x, num_shards=3, shard_id=1)
+        self.assertEqual(list(itr), [1, 4, 7, None])
+        itr = iterators.ShardedIterator(x, num_shards=3, shard_id=2)
+        self.assertEqual(list(itr), [2, 5, 8, None])
+
+        # test CountingIterator functionality
+        x = list(range(30))
+        ref = list(iterators.ShardedIterator(x, num_shards=3, shard_id=0))
+        itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0)
+        self.test_counting_iterator_index(ref, itr)
+
+    def test_counting_iterator_buffered_iterator_take(self):
+        ref = list(range(10))
+        buffered_itr = iterators.BufferedIterator(2, ref)
+        itr = iterators.CountingIterator(buffered_itr)
+        itr.take(5)
+        self.assertEqual(len(itr), len(list(iter(itr))))
+        self.assertEqual(len(itr), 5)
+
+        buffered_itr = iterators.BufferedIterator(2, ref)
+        itr = iterators.CountingIterator(buffered_itr)
+        itr.take(5)
+        self.assertEqual(len(buffered_itr), 5)
+        self.assertEqual(len(list(iter(buffered_itr))), 5)
+
+        buffered_itr = iterators.BufferedIterator(2, ref)
+        itr = iterators.CountingIterator(buffered_itr)
+        itr.take(5)
+        self.assertEqual(next(itr), ref[0])
+        self.assertEqual(next(itr), ref[1])
+        itr.skip(2)
+        self.assertEqual(next(itr), ref[4])
+        self.assertFalse(itr.has_next())
+        self.assertRaises(StopIteration, next, buffered_itr)
+
+        ref = list(range(4, 10))
+        buffered_itr = iterators.BufferedIterator(2, ref)
+        itr = iterators.CountingIterator(buffered_itr, start=4)
+        itr.take(5)
+        self.assertEqual(len(itr), 5)
+        self.assertEqual(len(buffered_itr), 1)
+        self.assertEqual(next(itr), ref[0])
+        self.assertFalse(itr.has_next())
+        self.assertRaises(StopIteration, next, buffered_itr)
+
+    def test_epoch_batch_iterator_skip_remainder_batch(self):
+        reference = [1, 2, 3]
+        itr1 = _get_epoch_batch_itr(reference, 2, True)
+        self.assertEqual(len(itr1), 1)
+        itr2 = _get_epoch_batch_itr(reference, 2, False)
+        self.assertEqual(len(itr2), 2)
+        itr3 = _get_epoch_batch_itr(reference, 1, True)
+        self.assertEqual(len(itr3), 2)
+        itr4 = _get_epoch_batch_itr(reference, 1, False)
+        self.assertEqual(len(itr4), 3)
+        itr5 = _get_epoch_batch_itr(reference, 4, True)
+        self.assertEqual(len(itr5), 0)
+        self.assertFalse(itr5.has_next())
+        itr6 = _get_epoch_batch_itr(reference, 4, False)
+        self.assertEqual(len(itr6), 1)
+
+    def test_grouped_iterator_skip_remainder_batch(self):
+        reference = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+        itr1 = _get_epoch_batch_itr(reference, 3, False)
+        grouped_itr1 = iterators.GroupedIterator(itr1, 2, True)
+        self.assertEqual(len(grouped_itr1), 1)
+
+        itr2 = _get_epoch_batch_itr(reference, 3, False)
+        grouped_itr2 = iterators.GroupedIterator(itr2, 2, False)
+        self.assertEqual(len(grouped_itr2), 2)
+
+        itr3 = _get_epoch_batch_itr(reference, 3, True)
+        grouped_itr3 = iterators.GroupedIterator(itr3, 2, True)
+        self.assertEqual(len(grouped_itr3), 1)
+
+        itr4 = _get_epoch_batch_itr(reference, 3, True)
+        grouped_itr4 = iterators.GroupedIterator(itr4, 2, False)
+        self.assertEqual(len(grouped_itr4), 1)
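+        # Note: with skip_remainder_batch=True the epoch iterator drops its
+        # final batch (itr3/itr4 yield 2 batches from 3, as the test above
+        # shows for bsz=1), so grouping by 2 gives a single group whether or
+        # not the GroupedIterator itself skips remainders.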
+
+        itr5 = _get_epoch_batch_itr(reference, 5, True)
+        grouped_itr5 = iterators.GroupedIterator(itr5, 2, True)
+        self.assertEqual(len(grouped_itr5), 0)
+
+        itr6 = _get_epoch_batch_itr(reference, 5, True)
+        grouped_itr6 = iterators.GroupedIterator(itr6, 2, False)
+        self.assertEqual(len(grouped_itr6), 1)
+
+
+def _get_epoch_batch_itr(ref, bsz, skip_remainder_batch):
+    dsz = len(ref)
+    indices = range(dsz)
+    starts = indices[::bsz]
+    batch_sampler = [indices[s : s + bsz] for s in starts]
+    dataset = ListDataset(ref)
+    itr = iterators.EpochBatchIterator(
+        dataset=dataset,
+        collate_fn=dataset.collater,
+        batch_sampler=batch_sampler,
+        skip_remainder_batch=skip_remainder_batch,
+    )
+    return itr.next_epoch_itr()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fairseq/tests/test_lm_context_window.py b/fairseq/tests/test_lm_context_window.py
new file mode 100644
index 0000000000000000000000000000000000000000..165e04ac3a55c1337b615110e2c695948689e785
--- /dev/null
+++ b/fairseq/tests/test_lm_context_window.py
@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from fairseq.data import MonolingualDataset
+from fairseq.tasks.language_modeling import LanguageModelingConfig, LanguageModelingTask
+from tests import utils as test_utils
+
+
+class TestLMContextWindow(unittest.TestCase):
+    def test_eval_dataloader(self):
+        dictionary = test_utils.dummy_dictionary(10)
+        assert len(dictionary) == 14  # 4 extra special symbols
+        assert dictionary.pad() == 1
+
+        dataset = test_utils.TestDataset(
+            [
+                torch.tensor([4, 5, 6, 7], dtype=torch.long),
+                torch.tensor([8, 9, 10, 11], dtype=torch.long),
+                torch.tensor([12, 13], dtype=torch.long),
+            ]
+        )
+        dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary)
+
+        config = LanguageModelingConfig(tokens_per_sample=4)
+        task = LanguageModelingTask(config, dictionary)
+
+        eval_dataloader = task.eval_lm_dataloader(
+            dataset=dataset,
+            batch_size=1,
+            context_window=2,
+            num_workers=0,
+        )
+
+        batch = next(eval_dataloader)
+        assert batch["net_input"]["src_tokens"][0].tolist() == [4, 5, 6, 7, 1, 1]
+        assert batch["target"][0].tolist() == [4, 5, 6, 7, 1, 1]
+
+        batch = next(eval_dataloader)
+        assert batch["net_input"]["src_tokens"][0].tolist() == [6, 7, 8, 9, 10, 11]
+        assert batch["target"][0].tolist() == [1, 1, 8, 9, 10, 11]
+
+        batch = next(eval_dataloader)
+        assert batch["net_input"]["src_tokens"][0].tolist() == [10, 11, 12, 13]
+        assert batch["target"][0].tolist() == [1, 1, 12, 13]
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fairseq/tests/test_lstm_jitable.py b/fairseq/tests/test_lstm_jitable.py
new file mode 100644
index 0000000000000000000000000000000000000000..38f79d17931c32447e96c0fbae2630ac397e1804
--- /dev/null
+++ b/fairseq/tests/test_lstm_jitable.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
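+#
+# The tests below check that LSTMModel can be scripted with TorchScript and
+# that the scripted module matches eager execution on random inputs.
+# Round-trip sketch (mirrors _test_save_and_load below):
+#
+#     scripted = torch.jit.script(model)
+#     scripted.save("model.pt")
+#     restored = torch.jit.load("model.pt")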
+
+import argparse
+import tempfile
+import unittest
+
+import torch
+from fairseq.data.dictionary import Dictionary
+from fairseq.models.lstm import LSTMModel
+from fairseq.tasks.fairseq_task import LegacyFairseqTask
+
+
+DEFAULT_TEST_VOCAB_SIZE = 100
+
+
+class DummyTask(LegacyFairseqTask):
+    def __init__(self, args):
+        super().__init__(args)
+        self.dictionary = get_dummy_dictionary()
+        if getattr(self.args, "ctc", False):
+            self.dictionary.add_symbol("<ctc_blank>")
+        self.src_dict = self.dictionary
+        self.tgt_dict = self.dictionary
+
+    @property
+    def source_dictionary(self):
+        return self.src_dict
+
+    @property
+    def target_dictionary(self):
+        return self.dictionary
+
+
+def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
+    dummy_dict = Dictionary()
+    # add dummy symbols until the requested vocab size is reached
+    for i in range(vocab_size):
+        dummy_dict.add_symbol("{}".format(i), 1000)
+    return dummy_dict
+
+
+def get_dummy_task_and_parser():
+    """
+    To build a fairseq model we need a dummy task and a dummy argument parser;
+    this function creates both to facilitate model/criterion tests.
+
+    Note: DummyTask (defined above) serves as the dummy task. You can test
+    with another task by providing a similar function for it.
+    """
+    parser = argparse.ArgumentParser(
+        description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
+    )
+    DummyTask.add_args(parser)
+    args = parser.parse_args([])
+    task = DummyTask.setup_task(args)
+    return task, parser
+
+
+class TestJitLSTMModel(unittest.TestCase):
+    def _test_save_and_load(self, scripted_module):
+        with tempfile.NamedTemporaryFile() as f:
+            scripted_module.save(f.name)
+            torch.jit.load(f.name)
+
+    def assertTensorEqual(self, t1, t2):
+        t1 = t1[~torch.isnan(t1)]  # can cause size mismatch errors if there are NaNs
+        t2 = t2[~torch.isnan(t2)]
+        self.assertEqual(t1.size(), t2.size(), "size mismatch")
+        self.assertEqual(t1.ne(t2).long().sum(), 0)
+
+    def test_jit_and_export_lstm(self):
+        task, parser = get_dummy_task_and_parser()
+        LSTMModel.add_args(parser)
+        args = parser.parse_args([])
+        args.criterion = ""
+        model = LSTMModel.build_model(args, task)
+        scripted_model = torch.jit.script(model)
+        self._test_save_and_load(scripted_model)
+
+    def test_assert_jit_vs_nonjit_(self):
+        task, parser = get_dummy_task_and_parser()
+        LSTMModel.add_args(parser)
+        args = parser.parse_args([])
+        args.criterion = ""
+        model = LSTMModel.build_model(args, task)
+        model.eval()
+        scripted_model = torch.jit.script(model)
+        scripted_model.eval()
+        idx = len(task.source_dictionary)
+        num_iters = 100
+        # Inject random input and check that jitted and eager outputs match
+        seq_len_tensor = torch.randint(1, 10, (num_iters,))
+        num_samples_tensor = torch.randint(1, 10, (num_iters,))
+        for i in range(num_iters):
+            seq_len = seq_len_tensor[i]
+            num_samples = num_samples_tensor[i]
+            src_token = (torch.randint(0, idx, (num_samples, seq_len)),)
+            src_lengths = torch.randint(1, seq_len + 1, (num_samples,))
+            src_lengths, _ = torch.sort(src_lengths, descending=True)
+            # Force the first sample to have seq_len
+            src_lengths[0] = seq_len
+            prev_output_token = (torch.randint(0, idx, (num_samples, 1)),)
+            result = model(src_token[0], src_lengths, prev_output_token[0], None)
+            scripted_result = scripted_model(
+                src_token[0], src_lengths, prev_output_token[0], None
+            )
+            self.assertTensorEqual(result[0], scripted_result[0])
+            self.assertTensorEqual(result[1], scripted_result[1])
+
+
+if __name__ == "__main__":
+    unittest.main()