[Binary additions, condensed: this commit also checks in compiled Python bytecode caches (*.cpython-310.pyc) under fairseq/fairseq/__pycache__/ and fairseq/fairseq/tasks/__pycache__/, together with a fairseq/fairseq.egg-info/dependency_links.txt containing a single blank line. Only the final binary stanza is shown below.]
diff --git a/fairseq/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-310.pyc b/fairseq/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-310.pyc new file mode 100644 index
0000000000000000000000000000000000000000..2a02c8caa494444234961971bc56acb9ade62652 Binary files /dev/null and b/fairseq/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-310.pyc differ diff --git a/fairseq/fairseq/tasks/__pycache__/translation_lev.cpython-310.pyc b/fairseq/fairseq/tasks/__pycache__/translation_lev.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00ae9095d42ef22a05766c97098954019de660fa Binary files /dev/null and b/fairseq/fairseq/tasks/__pycache__/translation_lev.cpython-310.pyc differ diff --git a/fairseq/fairseq/tasks/__pycache__/translation_multi_simple_epoch.cpython-310.pyc b/fairseq/fairseq/tasks/__pycache__/translation_multi_simple_epoch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5bd6f53d1434ea0cc5f8e4d686561bf3dee2c59 Binary files /dev/null and b/fairseq/fairseq/tasks/__pycache__/translation_multi_simple_epoch.cpython-310.pyc differ diff --git a/fairseq/fairseq/tasks/multires_hubert_pretraining.py b/fairseq/fairseq/tasks/multires_hubert_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..cfed147cb59965a55f52a0ffbc0c3d382c2b45bd --- /dev/null +++ b/fairseq/fairseq/tasks/multires_hubert_pretraining.py @@ -0,0 +1,204 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from dataclasses import dataclass, field +from fairseq.data import Dictionary, HubertDataset +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + return self.dictionary.encode_line( + label, + append_eos=False, + add_if_not_exist=False, + ) + + +@dataclass +class MultiresHubertPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + fine_tuning: bool = field( + default=False, metadata={"help": "set to true if fine-tuning Hubert"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr50", "ltr25"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: float = field( + default=-1.0, + metadata={"help": "label frame rate. -1.0 for sequence label"}, + ) + # label_rate: 1,2,2,5 + # (imply (1,2), (2,5)) + # if base label_rate = 50 + # (1,2), (2,5) --> label rates 50, 25, 10 + label_rate_ratios: List[int] = field(default=MISSING, metadata={"help": "tuple for label rates e.g., [(1,2), (2,5)]"}) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. 
audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + +@register_task("multires_hubert_pretraining", dataclass=MultiresHubertPretrainingConfig) +class MultiresHubertPretrainingTask(FairseqTask): + """ + Multiresolution HuBERT Pretraining Task. + The task is based on `HubertPretrainingTask` but extended to multiresolution. + """ + + cfg: MultiresHubertPretrainingConfig + + def __init__( + self, + cfg: MultiresHubertPretrainingConfig, + ) -> None: + super().__init__(cfg) + + logger.info(f"current directory is {os.getcwd()}") + logger.info(f"MultiresHubertPretrainingTask Config {cfg}") + + self.cfg = cfg + self.fine_tuning = cfg.fine_tuning + + if cfg.fine_tuning: + self.state.add_factory("target_dictionary", self.load_dictionaries) + self.res_number = 1 + else: + self.state.add_factory("dictionaries", self.load_dictionaries) + + self.blank_symbol = "" + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return None + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self.state.target_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return self.state.dictionaries + + @classmethod + def setup_task( + cls, cfg: MultiresHubertPretrainingConfig, **kwargs + ) -> "MultiresHubertPretrainingTask": + return cls(cfg) + + def load_dictionaries(self): + label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir + self.res_number = len(label_dir) + dictionaries = [ (Dictionary.load(f"{label_dir}/dict.{label}.txt") if label is not "" else None ) for label in self.cfg.labels] + return dictionaries[0] if self.cfg.fine_tuning else dictionaries + + def get_label_dir(self) -> str: + if self.cfg.label_dir is None: + return self.cfg.data + return self.cfg.label_dir + + def load_dataset(self, split: str, **kwargs) -> None: + manifest = f"{self.cfg.data}/{split}.tsv" + dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries + pad_list = [(dict.pad() if dict is not None else None) for dict in dicts] + eos_list = [(dict.eos() if dict is not None else None) for dict in dicts] + procs = [LabelEncoder(dict) for dict in dicts] + paths = [(f"{self.get_label_dir()}/{split}.{l}" if l != "" else None) for l in self.cfg.labels] + + base_rate = self.cfg.label_rate + self.label_rates = [base_rate] + label_rate_ratios = self.cfg.label_rate_ratios + self.label_rate_ratios = [] + for i in range(len(label_rate_ratios) // 2): + + upsample_rate, downsample_rate = label_rate_ratios[i * 2], label_rate_ratios[i * 2 + 1] + # parse label rate ratios + self.label_rate_ratios.append((upsample_rate, downsample_rate)) + base_rate = base_rate * upsample_rate // 
downsample_rate + self.label_rates.append(base_rate) + + # hubert v1: pad_audio=True, random_crop=False; + self.datasets[split] = HubertDataset( + manifest, + sample_rate=self.cfg.sample_rate, + label_paths=paths, + label_rates=self.label_rates, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=self.cfg.max_keep_size, + min_keep_sample_size=self.cfg.min_sample_size, + max_sample_size=self.cfg.max_sample_size, + pad_audio=self.cfg.pad_audio, + normalize=self.cfg.normalize, + store_labels=False, + random_crop=self.cfg.random_crop, + ) + + def max_positions(self) -> Tuple[int, int]: + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array: + return indices diff --git a/fairseq/fairseq/tasks/nlu_finetuning.py b/fairseq/fairseq/tasks/nlu_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..a335021335a417aaaf6e6a3b3a02f525ed933a46 --- /dev/null +++ b/fairseq/fairseq/tasks/nlu_finetuning.py @@ -0,0 +1,477 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import torch +import json + +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, Any + +from fairseq.data import AddTargetDataset, Dictionary, encoders +from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + +from . import register_task +from .. import utils +from ..logging import metrics + + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary): + self.dictionary = dictionary + + def __call__(self, label): + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False + ) + + +def label_len_fn(label): + return len(label.split(" ")) + + +@dataclass +class NLUFinetuningConfig(AudioPretrainingConfig): + # Options for reporting WER metrics during validation. 
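Aside on the label-rate bookkeeping added in multires_hubert_pretraining.py above: load_dataset consumes label_rate_ratios pairwise and rescales the base label_rate once per resolution. A minimal standalone sketch of that loop follows; the helper name expand_label_rates is made up for illustration, and the numbers follow the "label_rate: 1,2,2,5" comment in the config.

```python
def expand_label_rates(base_rate, ratios):
    """Mirror of the pairwise loop in MultiresHubertPretrainingTask.load_dataset (hypothetical helper)."""
    rates, pairs = [base_rate], []
    for i in range(len(ratios) // 2):
        up, down = ratios[i * 2], ratios[i * 2 + 1]
        pairs.append((up, down))
        base_rate = base_rate * up // down   # integer rescaling, exactly as in the task
        rates.append(base_rate)
    return rates, pairs

# base label_rate = 50 with label_rate_ratios = [1, 2, 2, 5]  ->  ratio pairs (1, 2) and (2, 5)
rates, pairs = expand_label_rates(50, [1, 2, 2, 5])
print(pairs)   # [(1, 2), (2, 5)]
print(rates)   # [50, 25, 10]
```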
Only applicable to + # Seq2Seq models during fine-tuning + eval_wer: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_parse: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_config: GenerationConfig = field( + default_factory=lambda: GenerationConfig(), + metadata={"help": "beam search config for evaluating wer during training"}, + ) + eval_wer_tokenizer: Any = field( + default=None, + metadata={"help": "tokenizer config for evaluating wer during training"}, + ) + eval_wer_post_process: str = field( + default="letter", + metadata={ + "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)" + }, + ) + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_detok: Optional[str] = field( + default=None, + metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); " + "required if using --eval-bleu; use 'space' to disable " + "detokenization; see fairseq.data.encoders for other options" + }, + ) + eval_bleu_detok_args: str = field( + default="{}", metadata={"help": "args for building the tokenizer, if needed"} + ) + eval_tokenized_bleu: bool = field( + default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, metadata={"help": "remove BPE before computing BLEU"} + ) + eval_bleu_args: str = field( + default="{}", + metadata={ + "help": "generation args for BLUE scoring, e.g., " + '\'{"beam": 4, "lenpen": 0.6}\'' + }, + ) + eval_bleu_print_samples: bool = field( + default=False, metadata={"help": "print sample generations during validation"} + ) + autoregressive: bool = field( + default=False, + metadata={ + "help": "required for autoregressive decoders (like seq2seq models); " + "adds 'prev_output_tokens' to input and appends eos to target" + }, + ) + + +@register_task("nlu_finetuning", dataclass=NLUFinetuningConfig) +class NLUFinetuningTask(AudioPretrainingTask): + """ """ + + cfg: NLUFinetuningConfig + + def __init__( + self, + cfg: NLUFinetuningConfig, + ): + super().__init__(cfg) + self.blank_symbol = "" + + self.state.add_factory("target_dictionary", self.load_target_dictionary) + + def load_target_dictionary(self): + if self.cfg.labels: + dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") + return Dictionary.load(dict_path) + return None + + def load_dataset(self, split: str, task_cfg: NLUFinetuningConfig = None, **kwargs): + super().load_dataset(split, task_cfg, **kwargs) + + task_cfg = task_cfg or self.cfg + assert task_cfg.labels is not None + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + data_path = self.cfg.data + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + 
batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level, + ) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.state.target_dictionary + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + if self.cfg.eval_wer_parse and self.cfg.autoregressive: + metrics = self._inference_with_wer_parse( + self.sequence_generator, sample, model + ) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + logging_output["_num_em_errors"] = metrics["num_em_errors"] + logging_output["_num_ems"] = metrics["num_ems"] + logging_output["_num_tree_errors"] = metrics["num_tree_errors"] + logging_output["_num_trees"] = metrics["num_trees"] + if self.cfg.eval_wer and self.cfg.autoregressive: + metrics = self._inference_with_wer(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + if self.cfg.eval_bleu and self.cfg.autoregressive: + metrics = self._inference_with_bleu(self.sequence_generator, sample, model) + logging_output["_bleu_sys_len"] = metrics.sys_len + logging_output["_bleu_ref_len"] = metrics.ref_len + # we split counts into separate entries so that they can be + # summed efficiently across workers using fast-stat-sync + assert len(metrics.counts) == 4 + for i in range(4): + logging_output[f"_bleu_counts_{i}"] = metrics.counts[i] + logging_output[f"_bleu_totals_{i}"] = metrics.totals[i] + return loss, sample_size, logging_output + + def build_model(self, model_cfg: FairseqDataclass): + model = super().build_model(model_cfg) + + if (self.cfg.eval_wer or self.cfg.eval_wer_parse) and self.cfg.autoregressive: + self.sequence_generator = self.build_generator( + [model], + self.cfg.eval_wer_config, + ) + if self.cfg.eval_wer_tokenizer: + self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) + else: + self.tokenizer = None + if self.cfg.eval_bleu and self.cfg.autoregressive: + assert self.cfg.eval_bleu_detok is not None, ( + "--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, e.g., when using sentencepiece)" + ) + detok_args = json.loads(self.cfg.eval_bleu_detok_args) + self.tokenizer = encoders.build_tokenizer( + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + ) + gen_args = json.loads(self.cfg.eval_bleu_args) + gen_args = Namespace(**gen_args) + self.sequence_generator = self.build_generator([model], gen_args) + + return model + + def _inference_with_wer_parse(self, generator, sample, model): + import editdistance + + def decode(toks): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + def decode_to_list(toks): + def token_string(i): + if i == self.target_dictionary.unk(): + return self.target_dictionary.unk_string(False) + else: + return self.target_dictionary[i] + 
+ return [token_string(i) for i in toks] + + def is_ont_token(token): + return "[" in token or "]" in token + + def post_process(l): + o = [] + for w in l: + if w == self.target_dictionary.eos_word or w == "|": + continue + if w == "_": + o.append(" ") + else: + o.append(w) + if is_ont_token(w): + o.append(" ") + return o + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + num_em_errors, num_ems = 0, 0 + num_tree_errors, num_trees = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp_tokens = gen_out[i][0]["tokens"] + # hyp = decode(hyp_tokens) + ref_tokens = utils.strip_pad( + sample["target"][i], self.target_dictionary.pad() + ) + # ref = decode(ref_tokens) + hyp_list = decode_to_list(hyp_tokens) + ref_list = decode_to_list(ref_tokens) + + hyp_list = post_process(hyp_list) + ref_list = post_process(ref_list) + + hyp = "".join(hyp_list).strip() + ref = "".join(ref_list).strip() + num_chars += len(ref) + num_char_errors += editdistance.eval(hyp, ref) + hyp_words = hyp.split() + ref_words = ref.split() + hyp_tree = [word for word in hyp_list if ("[" in word or "]" in word)] + ref_tree = [word for word in ref_list if ("[" in word or "]" in word)] + # num_word_errors += editdistance.eval(hyp_words, ref_words) + hyp_before = decode(hyp_tokens).split() + ref_before = decode(ref_tokens).split() + + num_word_errors += editdistance.eval(hyp_before, ref_before) + num_words += len(ref_before) + if hyp != ref: + num_em_errors += 1 + if hyp_tree != ref_tree: + num_tree_errors += 1 + num_ems += 1 + num_trees += 1 + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + "num_ems": num_ems, + "num_em_errors": num_em_errors, + "num_trees": num_trees, + "num_tree_errors": num_tree_errors, + } + + def _inference_with_wer(self, generator, sample, model): + import editdistance + + def decode(toks): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp = decode(gen_out[i][0]["tokens"]) + ref = decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + ) + num_char_errors += editdistance.eval(hyp, ref) + num_chars += len(ref) + hyp_words = hyp.split() + ref_words = ref.split() + num_word_errors += editdistance.eval(hyp_words, ref_words) + num_words += len(ref_words) + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + } + + def _inference_with_bleu(self, generator, sample, model): + import sacrebleu + + def decode(toks, is_ref): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_bleu_remove_bpe, + # The default unknown string in fairseq is ``, but + # this is tokenized by sacrebleu as `< unk >`, inflating + # BLEU scores. Instead, we use a somewhat more verbose + # alternative that is unlikely to appear in the real + # reference, but doesn't get split into multiple tokens. 
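The _inference_with_wer and _inference_with_wer_parse helpers above accumulate character- and word-level edit distances per batch; reduce_metrics later turns the summed counters into the uer/wer percentages. A self-contained sketch of the same arithmetic on toy strings (not real model output) is below.

```python
import editdistance  # same package used by _inference_with_wer above

hyp = "turn on the kitchen lights"
ref = "turn off the kitchen light"

num_char_errors = editdistance.eval(hyp, ref)                   # character-level edits
num_chars = len(ref)
num_word_errors = editdistance.eval(hyp.split(), ref.split())   # word-level edits
num_words = len(ref.split())

# The same ratios that reduce_metrics() derives from the summed counters:
uer = 100.0 * num_char_errors / num_chars
wer = 100.0 * num_word_errors / num_words
print(f"uer={uer:.1f}%  wer={wer:.1f}%")
```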
+ unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"), + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + gen_out = self.inference_step(generator, [model], sample) + hyps, refs = [], [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False)) + refs.append( + decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + is_ref=True, # don't count as matches to the hypo + ) + ) + if self.cfg.eval_bleu_print_samples: + logger.info("H-{} {}".format(sample["id"][0], hyps[0])) + logger.info("T-{} {}".format(sample["id"][0], refs[0])) + + eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a" + return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if self.cfg.eval_wer or self.cfg.eval_wer_parse: + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + metrics.log_scalar("_num_char_errors", num_char_errors) + metrics.log_scalar("_num_chars", num_chars) + metrics.log_scalar("_num_word_errors", num_word_errors) + metrics.log_scalar("_num_words", num_words) + if num_chars > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), + ) + if num_words > 0: + metrics.log_derived( + "wer", + lambda meters: meters["_num_word_errors"].sum + * 100.0 + / meters["_num_words"].sum + if meters["_num_words"].sum > 0 + else float("nan"), + ) + if self.cfg.eval_wer_parse: + num_em_errors = sum( + log.get("_num_em_errors", zero) for log in logging_outputs + ) + num_ems = sum(log.get("_num_ems", zero) for log in logging_outputs) + metrics.log_scalar("_num_em_errors", num_em_errors) + metrics.log_scalar("_num_ems", num_ems) + num_tree_errors = sum( + log.get("_num_tree_errors", zero) for log in logging_outputs + ) + num_trees = sum(log.get("_num_trees", zero) for log in logging_outputs) + metrics.log_scalar("_num_tree_errors", num_tree_errors) + metrics.log_scalar("_num_trees", num_trees) + + if num_ems > 0: + metrics.log_derived( + "em_error", + lambda meters: meters["_num_em_errors"].sum + * 100.0 + / meters["_num_ems"].sum + if meters["_num_ems"].sum > 0 + else float("nan"), + ) + if num_trees > 0: + metrics.log_derived( + "tree_error", + lambda meters: meters["_num_tree_errors"].sum + * 100.0 + / meters["_num_trees"].sum + if meters["_num_trees"].sum > 0 + else float("nan"), + ) + + if self.cfg.eval_bleu: + len_keys = ["_bleu_sys_len", "_bleu_ref_len"] + count_keys = [f"_bleu_counts_{i}" for i in range(4)] + total_keys = [f"_bleu_totals_{i}" for i in range(4)] + for k in len_keys + count_keys + total_keys: + metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs)) + + import sacrebleu + + metrics.log_derived( + "bleu", + lambda meters: sacrebleu.compute_bleu( + correct=[meters[k].sum for k in count_keys], + total=[meters[k].sum for k in total_keys], + sys_len=meters["_bleu_sys_len"].sum, + ref_len=meters["_bleu_ref_len"].sum, + smooth_method="exp", + ).score, + ) diff --git a/fairseq/fairseq/tasks/online_backtranslation.py 
b/fairseq/fairseq/tasks/online_backtranslation.py new file mode 100644 index 0000000000000000000000000000000000000000..da24fe8981cd6a2b6f953b3e6646082c1758b0b5 --- /dev/null +++ b/fairseq/fairseq/tasks/online_backtranslation.py @@ -0,0 +1,683 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import json +import logging +import math +import os +from argparse import Namespace +from collections import OrderedDict, defaultdict +from pathlib import Path +from typing import Dict, Sequence, Tuple +from argparse import ArgumentError + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import fairseq +from fairseq import options, utils +from fairseq.logging import metrics +from fairseq.data import ( + FairseqDataset, + LanguagePairDataset, + NoisingDataset, + PrependTokenDataset, + RoundRobinZipDatasets, + TransformEosLangPairDataset, + data_utils, + encoders, +) +from fairseq.sequence_generator import SequenceGenerator +from fairseq.tasks import register_task +from fairseq.tasks.translation import TranslationTask, load_langpair_dataset + +logger = logging.getLogger(__name__) + + +class PiecewiseLinearFn: + """Piecewise linear function. Can be configured with a string.""" + + def __init__(self, pieces: Sequence[Tuple[int, float]]): + assert pieces == sorted( + pieces + ), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}" + + self.pieces = pieces + + def __call__(self, x: int) -> float: + for i, (x_a, y_a) in enumerate(self.pieces[:-1]): + x_b, y_b = self.pieces[i + 1] + if x_a <= x <= x_b: + return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a) + + return self.pieces[-1][1] + + @staticmethod + def from_string(configuration: str) -> "PiecewiseLinearFn": + """ + Parse the configuration of lambda coefficient (for scheduling). 
+ x = "3" # lambda will be a constant equal to x + x = "0:1,1000:0" # lambda will start from 1 and linearly decrease + # to 0 during the first 1000 iterations + x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 + # iterations, then will linearly increase to 1 until iteration 2000 + """ + if isinstance(configuration, float): + return PiecewiseLinearFn([(0, configuration)]) + + try: + parts = configuration.split(",") + if len(parts) == 1: + v = float(configuration) + return PiecewiseLinearFn([(0, v)]) + + split = [s.split(":") for s in parts] + pieces = [(int(t), float(v)) for t, v in split] + return PiecewiseLinearFn(pieces) + except Exception: + raise ValueError( + f"Invalid PiecewiseLinearFn configuration: {configuration!r}" + ) + + @staticmethod + def one() -> "PiecewiseLinearFn": + return PiecewiseLinearFn([(0, 1.0)]) + + +@register_task("online_backtranslation") +class OnlineBackTranslationTask(TranslationTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + # Generic translation args + parser.add_argument('data', help='colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner; \ + however, valid and test data are always in the first directory to \ + avoid the need for repeating them in all directories') + parser.add_argument('--mono-langs', metavar='MONO_LANGS', + help='monolingual languages for training') + parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS', + help='language pairs for validation') + parser.add_argument('--load-alignments', action='store_true', + help='load the binarized alignments') + parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL', + help='pad the source on the left') + parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', + help='pad the target on the left') + parser.add_argument('--upsample-primary', default=1, type=int, + help='amount to upsample primary dataset') + try: + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. 
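PiecewiseLinearFn above turns a schedule string (the format used by the --lambda-bt / --lambda-dae arguments registered further down) into a per-update coefficient. A short sketch of the semantics, assuming this fairseq checkout is importable; otherwise the class can be copied verbatim from the hunk above.

```python
from fairseq.tasks.online_backtranslation import PiecewiseLinearFn

# "0:1,1000:0": start at 1.0 and decay linearly to 0.0 over the first 1000 updates
lambda_bt = PiecewiseLinearFn.from_string("0:1,1000:0")
print([lambda_bt(step) for step in (0, 250, 500, 1000, 4000)])
# -> [1.0, 0.75, 0.5, 0.0, 0.0]   (past the last knot, the final value is held)

# A plain number gives a constant coefficient
lambda_dae = PiecewiseLinearFn.from_string("1.0")
print(lambda_dae(12345))  # -> 1.0
```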
+ pass + parser.add_argument('--truncate-source', action='store_true', default=False, + help='truncate source to max-source-positions') + parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N', + help='if >0, then bucket source and target lengths into N ' + 'buckets and pad accordingly; this is useful on TPUs ' + 'to minimize the number of compilations') + + # Denoising args + parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N', + help='maximum word shuffle distance for denoising autoencoding data generation') + parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N', + help='word dropout probability for denoising autoencoding data generation') + parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N', + help='word blanking probability for denoising autoencoding data generation') + + # Backtranslation args + parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N', + help='back-translation weight') + parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N', + help='denoising auto-encoder weight') + + # Evaluation args + parser.add_argument('--generate-one-by-one', action='store_true', + help='generate one sentence at a time for backtranslation') + + parser.add_argument('--eval-bleu', action='store_true', + help='evaluation with BLEU scores') + parser.add_argument('--eval-bleu-detok', type=str, default="space", + help='detokenize before computing BLEU (e.g., "moses"); ' + 'required if using --eval-bleu; use "space" to ' + 'disable detokenization; see fairseq.data.encoders ' + 'for other options') + parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON', + help='args for building the tokenizer, if needed') + parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False, + help='compute tokenized BLEU instead of sacrebleu') + parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None, + help='remove BPE before computing BLEU') + parser.add_argument('--eval-bleu-args', type=str, metavar='JSON', + help='generation args for BLUE scoring, ' + 'e.g., \'{"beam": 4, "lenpen": 0.6}\'') + parser.add_argument('--eval-bleu-print-samples', action='store_true', + help='print sample generations during validation') + # fmt: on + + def __init__(self, args, common_dict, mono_langs, valid_lang_pairs): + super().__init__(args, common_dict, common_dict) + self.common_dict = common_dict + self.mono_langs = mono_langs + self.valid_lang_pairs = valid_lang_pairs + + self.SHOW_SAMPLES_INTERVAL = 1000 + # Start by showing samples + self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL + self.SHOW_SAMPLES_NUMBER = 5 + self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt) + self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae) + + self.args = args + self.data = utils.split_paths(self.args.data) + if len(self.data) == 1: + shards = list(Path(self.data[0]).glob("shard*")) + if len(shards) > 0: + # keep this as strings, since it can also be a manifold path + old_data = self.data + self.data = [str(shard) for shard in shards] + logging.warning(f"Expanded data directory {old_data} to {self.data}") + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + """ + args.left_pad_source = options.eval_bool(args.left_pad_source) + args.left_pad_target = options.eval_bool(args.left_pad_target) + + paths = utils.split_paths(args.data) + assert len(paths) > 0 + assert args.mono_langs is not None + + mono_langs = args.mono_langs.split(",") + valid_lang_pairs = args.valid_lang_pairs.split(",") + + # load dictionary + dict_path = os.path.join(paths[0], "dict.txt") + common_dict = cls.load_dictionary(dict_path) + + return cls(args, common_dict, mono_langs, valid_lang_pairs) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset: + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if split == "train": + data_path = self.data[(epoch - 1) % len(self.data)] + dataset = self.load_train_dataset(data_path) + else: + # valid/test should always be the same. + dataset = self.load_translation_dataset(split, self.data[0]) + + self.datasets[split] = dataset + return dataset + + def load_train_dataset(self, data_path: str) -> FairseqDataset: + """The training dataset is made of backtranslation dataset and denoising dataset.""" + data = [] + for lang in self.mono_langs: + train_path = os.path.join(data_path, lang, "train") + # TODO: could we do the BT using denoise sample ? + # this would half the data loading work + data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang))) + data.append( + (f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang)) + ) + + return RoundRobinZipDatasets(OrderedDict(data)) + + def _langpair_dataset( + self, src: FairseqDataset, tgt: FairseqDataset + ) -> LanguagePairDataset: + return LanguagePairDataset( + src, + src.sizes, + self.dictionary, + tgt=tgt, + tgt_sizes=tgt.sizes, + tgt_dict=self.dictionary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + # TODO: should we shuffle ? we are already sorting batch by sizes so ? + # shuffle=True, + ) + + def _prepend_lang_bos_to_target( + self, dataset: LanguagePairDataset, lang: str + ) -> LanguagePairDataset: + bos = _lang_token_index(self.dictionary, lang) + return TransformEosLangPairDataset( + dataset, + src_eos=self.dictionary.eos(), + new_src_eos=self.dictionary.eos(), + tgt_bos=self.dictionary.eos(), + new_tgt_bos=bos, + ) + + def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset: + """The BT dataset is generated with (tgt, tgt) pairs. + The actual translation to a (generated_src, tgt) pair + is done on the fly during training. 
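Both the back-translation and denoising pipelines tag every sequence with a language token: _lang_token / _lang_token_index (defined at the end of this file) map a language to a dictionary symbol, PrependTokenDataset prepends it on the source side, and _prepend_lang_bos_to_target swaps the target BOS for it. A minimal sketch of the convention with a toy dictionary, assuming this checkout is on the path; the languages are illustrative only.

```python
from fairseq.data import Dictionary
from fairseq.tasks.online_backtranslation import _lang_token, _lang_token_index

d = Dictionary()  # starts with the usual special symbols (bos, pad, eos, unk)
for lang in ("en", "ro"):
    d.add_symbol(_lang_token(lang))   # add_secial_tokens_to_dict_and_model does this for mono_langs

print(_lang_token("ro"))            # "__ro__"
print(_lang_token_index(d, "ro"))   # index of "__ro__" in the toy dictionary (specials occupy the first slots)
# _prepend_lang_bos_to_target() then uses this index as the new target BOS, telling the
# decoder which language to produce.
```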
+ """ + mono_dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + assert mono_dataset is not None, f"No dataset found for {lang}" + + mono_dataset_src = PrependTokenDataset( + mono_dataset, _lang_token_index(self.dictionary, lang) + ) + + mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset) + logger.info( + f"mono_lang = {lang} " + f"lang token index = {_lang_token_index(self.dictionary, lang)} " + f"lang token = {_lang_token(lang)}" + ) + + mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang) + return mono_dataset_bt + + def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset: + """Classic denoising dataset""" + dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + noisy_dataset = NoisingDataset( + dataset, + self.dictionary, + seed=1, + max_word_shuffle_distance=self.args.max_word_shuffle_distance, + word_dropout_prob=self.args.word_dropout_prob, + word_blanking_prob=self.args.word_blanking_prob, + ) + noisy_dataset = PrependTokenDataset( + noisy_dataset, _lang_token_index(self.dictionary, lang) + ) + + clean_dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset) + denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang) + return denoising_dataset + + def load_translation_dataset( + self, split: str, data_path: str, combine: bool = False + ): + # only judging with one language pair for the moment, + # since ConcatDataset doesn't work as expected + assert len(self.valid_lang_pairs) == 1, "For now..." + valid_lang_pair = self.valid_lang_pairs[0] + src, tgt = valid_lang_pair.split("-") + + # use the same function than TranslationTask + src_tgt_dt = load_langpair_dataset( + data_path, + split, + src, + self.common_dict, + tgt, + self.common_dict, + combine=combine, + dataset_impl=self.args.dataset_impl, + upsample_primary=self.args.upsample_primary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + load_alignments=self.args.load_alignments, + truncate_source=self.args.truncate_source, + num_buckets=self.args.num_batch_buckets, + shuffle=(split != "test"), + prepend_bos_src=_lang_token_index(self.dictionary, src), + ) + + src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt) + src_tgt_eos_dt.args = self.args + return src_tgt_eos_dt + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + raise NotImplementedError + + def build_model(self, args, from_checkpoint=False): + # torch.autograd.set_detect_anomaly(True) + model = super().build_model(args, from_checkpoint) + + add_secial_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs) + + self.sequence_generators = {} + for mono_lang in self.mono_langs: + self.sequence_generators[mono_lang] = SequenceGenerator( + [model], + tgt_dict=self.dictionary, + beam_size=1, + max_len_a=1.3, + max_len_b=5, + min_len=5, + # keep 1 to be able to prepend bos + max_len=model.max_decoder_positions() - 1, + ) + + if getattr(args, "eval_bleu", False): + assert getattr(args, "eval_bleu_detok", None) is not None, ( + "--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, 
e.g., when using sentencepiece)" + ) + detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}") + self.tokenizer = encoders.build_tokenizer( + Namespace( + tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args + ) + ) + + gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}") + self.bleu_sequence_generator = self.build_generator( + [model], Namespace(**gen_args) + ) + + return model + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) + + @property + def dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.common_dict + + def display_samples_once_in_a_while(self, smp, mono_lang, other_lang): + self._show_samples_ctr += 1 + if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL: + return + self._show_samples_ctr = 0 + + ln = smp["net_input"]["src_tokens"].shape[0] + + logger.info( + f"(r:{self.args.distributed_rank}) : " + f"{other_lang} ---> {mono_lang} " + f"({other_lang} was generated by back-translation.) {ln} samples" + ) + + for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)): + src_tokens = smp["net_input"]["src_tokens"][i] + tgt_tokens = smp["target"][i] + + src_str = self.dictionary.string(src_tokens, "sentencepiece") + tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece") + logger.info( + f"\n{i}\t\t[{other_lang} generated] {src_str}\n" + f"\t\t[{mono_lang} original ] {tgt_str}\n" + f"\t\t[ src tokens] {src_tokens}\n" + ) + + def backtranslate_sample(self, smp, orig_lang, other_lang) -> None: + """ + * WARNING: smp is modified in place. + * At the start of this function, `smp` has the same input and target: + |--------------------------------------------------------| + | smp['net_input']['src_tokens'] | smp['target'] | + | (from data) __en__ hello world | __en__ hello world | + |--------------------------------------------------------| + + * We call generator.generate(smp, bos_token = token("ro")), + and copy the result as input + * At the end, `smp` has the translation to other language. 
+ |--------------------------------------------------------| + | smp['net_input']['src_tokens'] | smp['target'] | + | (generated) __ro__ salut lume | __en__ hello world | + |--------------------------------------------------------| + + """ + bos_token = _lang_token_index(self.dictionary, other_lang) + generated = self.sequence_generators[orig_lang].generate( + models=[], sample=smp, bos_token=bos_token + ) + + max_lngth = max([gn[0]["tokens"].size(0) for gn in generated]) + net_input = smp["net_input"] + n_src_tokens = torch.empty( + size=(len(generated), max_lngth + 1), dtype=net_input["src_tokens"].dtype + ) + n_src_lengths = torch.empty( + len(generated), dtype=net_input["src_lengths"].dtype + ) + + for i, gn in enumerate(generated): + tokens = gn[0]["tokens"] + tokens_size = tokens.size(0) + padding_needed = max_lngth - tokens_size + tokens = torch.cat([tokens.new([bos_token]), tokens]) + tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad()) + n_src_tokens[i] = tokens + n_src_lengths[i] = tokens_size + 1 + + device = net_input["src_tokens"].device + # This seems to be important + del net_input["src_tokens"] + del net_input["src_lengths"] + net_input["src_tokens"] = n_src_tokens.to(device) + net_input["src_lengths"] = n_src_lengths.to(device) + + def generate(self, smp, model): + model.eval() + orig_lang = ( + self.dictionary[smp["net_input"]["src_tokens"][0][0]] + .replace(" ", "") + .replace("_", "") + ) + bos_token = smp["net_input"]["prev_output_tokens"][0][0] + with torch.no_grad(): + generated = self.sequence_generators[orig_lang].generate( + models=[model], sample=smp, bos_token=bos_token + ) + return generated + + def get_other_lang(self, lang): + # TODO: allow more complex mapping + if lang != self.mono_langs[0]: + return self.mono_langs[0] + if len(self.mono_langs) == 2: + return self.mono_langs[1] + return self.mono_langs[np.random.randint(1, len(self.mono_langs))] + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + + model.train() + model.set_num_updates(update_num) + + agg_loss, agg_sample_size = 0.0, 0.0 + agg_logging_output: Dict[str, float] = defaultdict(float) + + dataset_keys = self.datasets["train"].datasets.keys() + + weights = { + "BT": self.lambda_bt(update_num), + "DENOISE": self.lambda_dae(update_num), + } + log_keys = {"BT": "bt_", "DENOISE": "dae_"} + + for dataset_key in dataset_keys: + smp = sample[dataset_key] + mono_lang, task_subtype = dataset_key.split("-") + if weights[task_subtype] == 0: + continue + + if task_subtype == "BT": + with torch.autograd.profiler.record_function("backtranslation"): + model.eval() + # TODO: Could we translate to several language at once ? + # this would allow to share encoder_out and maximize GPU usage. 
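backtranslate_sample above rebuilds net_input from variable-length generator output: each hypothesis gets the target-language token prepended and is right-padded to the longest sequence in the batch. The tensor manipulation in isolation, with made-up token ids and pad index:

```python
import torch
import torch.nn.functional as F

# Pretend the generator returned two hypotheses of different lengths.
generated = [torch.tensor([11, 12, 13, 2]), torch.tensor([21, 22, 2])]
bos_token, pad_index = 250_001, 1   # hypothetical __ro__ index and pad id

max_length = max(t.size(0) for t in generated)
batch = torch.empty((len(generated), max_length + 1), dtype=torch.long)
lengths = torch.empty(len(generated), dtype=torch.long)

for i, tokens in enumerate(generated):
    tokens_size = tokens.size(0)
    padding_needed = max_length - tokens_size
    tokens = torch.cat([tokens.new([bos_token]), tokens])          # prepend the language token
    tokens = F.pad(tokens, (0, padding_needed), value=pad_index)   # right-pad to a rectangle
    batch[i] = tokens
    lengths[i] = tokens_size + 1                                   # original length + language token

print(batch.shape)   # torch.Size([2, 5])
print(lengths)       # tensor([5, 4])
```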
+ other_lang = self.get_other_lang(mono_lang) + self.backtranslate_sample(smp, mono_lang, other_lang) + self.display_samples_once_in_a_while(smp, mono_lang, other_lang) + model.train() + + # Like in FairseqTask.train_step + with torch.autograd.profiler.record_function("forward"): + loss, sample_size, logging_output = criterion(model, smp) + loss *= weights[task_subtype] + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + + agg_loss += loss.item() + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[log_keys[task_subtype] + k] += logging_output[k] + agg_logging_output[k] += logging_output[k] + + return agg_loss, agg_sample_size, agg_logging_output + + def get_bos_token_from_sample(self, sample): + net_input = sample["net_input"] + source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item() + source_lang_token = self.dictionary[source_lang_token_id].replace("_", "") + target_lang_token_id = _lang_token_index( + self.dictionary, self.get_other_lang(source_lang_token) + ) + + return target_lang_token_id + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs) + if bt_sample_size: + bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs) + bt_loss_sum *= 1 / bt_sample_size / math.log(2) + metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3) + + bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs) + bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs) + bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2) + metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3) + metrics.log_derived( + "bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg) + ) + + dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs) + if dae_sample_size: + dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs) + dae_loss_sum *= 1 / dae_sample_size / math.log(2) + metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3) + + dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs) + dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs) + dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2) + metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3) + metrics.log_derived( + "dae_ppl", + lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg), + ) + + +@torch.no_grad() +def extend_embedding( + emb: nn.Module, new_vocab_size: int, copy_from_token_id: int +) -> None: + old_emb_data = emb.weight.data + (old_vocab_size, dim) = old_emb_data.shape + assert new_vocab_size >= old_vocab_size + + if new_vocab_size > old_vocab_size: + emb.weight.data = torch.zeros((new_vocab_size, dim)) + emb.weight.data[:old_vocab_size, :] = old_emb_data + # initialize new embeddings + emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id] + if hasattr(emb, "num_embeddings"): + emb.num_embeddings = new_vocab_size + if hasattr(emb, "out_features"): + emb.out_features = new_vocab_size + + if getattr(emb, "bias", None) is None: + return + + # Fix the bias. + # Bias shape can be different from the previous vocab size + # if the weight matrix was shared and alread extended but not the bias. 
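+ # (e.g. an output projection whose weight is tied to the encoder embedding:
+ # resizing the embedding already grew the shared weight matrix, while the
+ # projection keeps its own bias vector with the old vocabulary length.)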
+ (old_vocab_size,) = emb.bias.shape + assert new_vocab_size >= old_vocab_size + if new_vocab_size > old_vocab_size: + old_bias = emb.bias.data + new_bias = torch.zeros( + (new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device + ) + new_bias[:old_vocab_size] = old_bias + emb.bias.data = new_bias + + +def add_secial_tokens_to_dict_and_model( + dictionary: "fairseq.data.Dictionary", + model: nn.Module, + mono_langs: Sequence[str], +) -> None: + embs = model.encoder.embed_tokens + vocab_size, embedding_dim = embs.weight.shape + + # The model may or may not have a '' embedding yet + assert ( + len(dictionary) <= vocab_size <= len(dictionary) + 1 + ), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})" + # TODO: we should reuse the pretrained model dict which already has + dictionary.add_symbol("") + + for lang in mono_langs: + lang_token = _lang_token(lang) + dictionary.add_symbol(lang_token) + logger.info( + f"dictionary: {len(dictionary)} -> {vocab_size} tokens " + f"after adding {len(mono_langs)} lang tokens." + ) + + if len(dictionary) <= vocab_size: + return + + extend_embedding(embs, len(dictionary), dictionary.bos()) + dec_embs = model.decoder.embed_tokens + extend_embedding(dec_embs, len(dictionary), dictionary.bos()) + lm_head = model.decoder.output_projection + extend_embedding(lm_head, len(dictionary), dictionary.bos()) + assert lm_head.weight.shape == (len(dictionary), embedding_dim) + + +def _lang_token(lang: str) -> str: + return f"__{lang}__" + + +def _lang_token_index(dictionary, lang: str) -> int: + return dictionary.index(_lang_token(lang)) + + +@contextlib.contextmanager +def assert_weights_have_changed(model: nn.Module): + def checksum(model: nn.Module) -> float: + return sum(p.sum().item() for p in model.parameters()) + + initial_checksum = checksum(model) + yield model + final_checksum = checksum(model) + logger.info( + f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}" + ) + assert initial_checksum != final_checksum, "Model hasn't changed !" diff --git a/fairseq/fairseq/tasks/semisupervised_translation.py b/fairseq/fairseq/tasks/semisupervised_translation.py new file mode 100644 index 0000000000000000000000000000000000000000..432b8a52ca122bca3e3f24a1fd493da33614e742 --- /dev/null +++ b/fairseq/fairseq/tasks/semisupervised_translation.py @@ -0,0 +1,485 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from collections import OrderedDict + +from fairseq import utils +from fairseq.data import ( + BacktranslationDataset, + IndexedCachedDataset, + IndexedDataset, + IndexedRawTextDataset, + LanguagePairDataset, + NoisingDataset, + RoundRobinZipDatasets, + data_utils, + indexed_dataset, +) +from fairseq.models import FairseqMultiModel +from fairseq.sequence_generator import SequenceGenerator + +from . import register_task +from .multilingual_translation import MultilingualTranslationTask + + +logger = logging.getLogger(__name__) + + +def _get_bt_dataset_key(lang_pair): + return "bt:" + lang_pair + + +def _get_denoising_dataset_key(lang_pair): + return "denoising:" + lang_pair + + +# ported from UnsupervisedMT +def parse_lambda_config(x): + """ + Parse the configuration of lambda coefficient (for scheduling). 
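+ Each entry is an `iteration:value` pair; between two consecutive iterations
+ the coefficient is linearly interpolated, and after the last listed iteration
+ it is held constant.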
+ x = "3" # lambda will be a constant equal to x + x = "0:1,1000:0" # lambda will start from 1 and linearly decrease + # to 0 during the first 1000 iterations + x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 + # iterations, then will linearly increase to 1 until iteration 2000 + """ + split = x.split(",") + if len(split) == 1: + return float(x), None + else: + split = [s.split(os.pathsep) for s in split] + assert all(len(s) == 2 for s in split) + assert all(k.isdigit() for k, _ in split) + assert all( + int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1) + ) + return float(split[0][1]), [(int(k), float(v)) for k, v in split] + + +@register_task("semisupervised_translation") +class SemisupervisedTranslationTask(MultilingualTranslationTask): + """A task for training multiple translation models simultaneously. + + We iterate round-robin over batches from multiple language pairs, ordered + according to the `--lang-pairs` argument. + + The training loop is roughly: + + for i in range(len(epoch)): + for lang_pair in args.lang_pairs: + batch = next_batch_for_lang_pair(lang_pair) + loss = criterion(model_for_lang_pair(lang_pair), batch) + loss.backward() + optimizer.step() + + In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset + (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that + implements the `FairseqMultiModel` interface. + + During inference it is required to specify a single `--source-lang` and + `--target-lang`, instead of `--lang-pairs`. + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + MultilingualTranslationTask.add_args(parser) + parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG', + help='cross-entropy reconstruction coefficient (parallel data). ' + 'use fixed weight during training if set to floating point number. ' + 'use piecewise linear function over number of updates to schedule the ' + 'weight with the format: w0:step0,w1:step1,...') + parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG', + help='Cross-entropy reconstruction coefficient (denoising autoencoding)' + 'use fixed weight during training if set to floating point number. ' + 'use piecewise linear function over number of updates to schedule the ' + 'weight with the format: w0:step0,w1:step1,...') + parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG', + help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data)' + 'use fixed weight during training if set to floating point number. 
' + 'use piecewise linear function over number of updates to schedule the ' + 'weight with the format: w0:step0,w1:step1,...') + parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N', + help='generate back-translated sequences of maximum length ax + b, where x is the ' + 'source length') + parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N', + help='generate back-translated sequences of maximum length ax + b, where x is the ' + 'source length') + parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N', + help='beam size used in beam search of online back-translation') + parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N', + help='maximum word shuffle distance for denoising autoencoding data generation') + parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N', + help='word dropout probability for denoising autoencoding data generation') + parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N', + help='word blanking probability for denoising autoencoding data generation') + # fmt: on + + def __init__(self, args, dicts, training): + super().__init__(args, dicts, training) + self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config( + args.lambda_parallel_config + ) + self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config( + args.lambda_otf_bt_config + ) + self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config( + args.lambda_denoising_config + ) + if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None: + denoising_lang_pairs = [ + "%s-%s" % (tgt, tgt) + for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs} + ] + self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs + self.backtranslate_datasets = {} + self.backtranslators = {} + + @classmethod + def setup_task(cls, args, **kwargs): + dicts, training = MultilingualTranslationTask.prepare(args, **kwargs) + return cls(args, dicts, training) + + def load_dataset(self, split, epoch=1, **kwargs): + """Load a dataset split.""" + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + def split_exists(split, src, tgt, lang): + if src is not None: + filename = os.path.join( + data_path, "{}.{}-{}.{}".format(split, src, tgt, lang) + ) + else: + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, src, tgt) + ) + return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl) + + def load_indexed_dataset(path, dictionary): + return data_utils.load_indexed_dataset( + path, dictionary, self.args.dataset_impl + ) + + # load parallel datasets + src_datasets, tgt_datasets = {}, {} + if ( + self.lambda_parallel > 0.0 + or self.lambda_parallel_steps is not None + or not split.startswith("train") + ): + for lang_pair in self.lang_pairs: + src, tgt = lang_pair.split("-") + if split_exists(split, src, tgt, src): + prefix = os.path.join( + data_path, "{}.{}-{}.".format(split, src, tgt) + ) + elif split_exists(split, tgt, src, src): + prefix = os.path.join( + data_path, "{}.{}-{}.".format(split, tgt, src) + ) + else: + continue + src_datasets[lang_pair] = load_indexed_dataset( + prefix + src, self.dicts[src] + ) + tgt_datasets[lang_pair] = load_indexed_dataset( + prefix + tgt, self.dicts[tgt] + ) + logger.info( + "parallel-{} {} {} examples".format( + data_path, split, len(src_datasets[lang_pair]) + ) + ) + if len(src_datasets) == 0: + raise 
FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + # back translation datasets + backtranslate_datasets = {} + if ( + self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None + ) and split.startswith("train"): + for lang_pair in self.lang_pairs: + src, tgt = lang_pair.split("-") + if not split_exists(split, tgt, None, tgt): + raise FileNotFoundError( + "Dataset not found: backtranslation {} ({})".format( + split, data_path + ) + ) + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, tgt, tgt) + ) + dataset = load_indexed_dataset(filename, self.dicts[tgt]) + lang_pair_dataset_tgt = LanguagePairDataset( + dataset, + dataset.sizes, + self.dicts[tgt], + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ) + lang_pair_dataset = LanguagePairDataset( + dataset, + dataset.sizes, + src_dict=self.dicts[src], + tgt=dataset, + tgt_sizes=dataset.sizes, + tgt_dict=self.dicts[tgt], + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ) + backtranslate_datasets[lang_pair] = BacktranslationDataset( + tgt_dataset=self.alter_dataset_langtok( + lang_pair_dataset_tgt, + src_eos=self.dicts[tgt].eos(), + src_lang=tgt, + tgt_lang=src, + ), + backtranslation_fn=self.backtranslators[lang_pair], + src_dict=self.dicts[src], + tgt_dict=self.dicts[tgt], + output_collater=self.alter_dataset_langtok( + lang_pair_dataset=lang_pair_dataset, + src_eos=self.dicts[src].eos(), + src_lang=src, + tgt_eos=self.dicts[tgt].eos(), + tgt_lang=tgt, + ).collater, + ) + logger.info( + "backtranslate-{}: {} {} {} examples".format( + tgt, + data_path, + split, + len(backtranslate_datasets[lang_pair]), + ) + ) + self.backtranslate_datasets[lang_pair] = backtranslate_datasets[ + lang_pair + ] + + # denoising autoencoder + noising_datasets = {} + if ( + self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None + ) and split.startswith("train"): + for lang_pair in self.lang_pairs: + _, tgt = lang_pair.split("-") + if not split_exists(split, tgt, None, tgt): + continue + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, tgt, tgt) + ) + tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt]) + tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt]) + noising_dataset = NoisingDataset( + tgt_dataset1, + self.dicts[tgt], + seed=1, + max_word_shuffle_distance=self.args.max_word_shuffle_distance, + word_dropout_prob=self.args.word_dropout_prob, + word_blanking_prob=self.args.word_blanking_prob, + ) + noising_datasets[lang_pair] = self.alter_dataset_langtok( + LanguagePairDataset( + noising_dataset, + tgt_dataset1.sizes, + self.dicts[tgt], + tgt_dataset2, + tgt_dataset2.sizes, + self.dicts[tgt], + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ), + src_eos=self.dicts[tgt].eos(), + src_lang=tgt, + tgt_eos=self.dicts[tgt].eos(), + tgt_lang=tgt, + ) + logger.info( + "denoising-{}: {} {} {} examples".format( + tgt, + data_path, + split, + len(noising_datasets[lang_pair]), + ) + ) + + def language_pair_dataset(lang_pair): + src, tgt = lang_pair.split("-") + src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair] + return self.alter_dataset_langtok( + LanguagePairDataset( + src_dataset, + src_dataset.sizes, + self.dicts[src], + tgt_dataset, + tgt_dataset.sizes, + self.dicts[tgt], + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ), + self.dicts[src].eos(), + src, + 
self.dicts[tgt].eos(), + tgt, + ) + + self.datasets[split] = RoundRobinZipDatasets( + OrderedDict( + [ + (lang_pair, language_pair_dataset(lang_pair)) + for lang_pair in src_datasets.keys() + ] + + [ + (_get_bt_dataset_key(lang_pair), dataset) + for lang_pair, dataset in backtranslate_datasets.items() + ] + + [ + (_get_denoising_dataset_key(lang_pair), dataset) + for lang_pair, dataset in noising_datasets.items() + ] + ), + eval_key=None + if self.training + else "%s-%s" % (self.args.source_lang, self.args.target_lang), + ) + + def build_model(self, args, from_checkpoint=False): + from fairseq import models + + model = models.build_model(args, self, from_checkpoint) + if not isinstance(model, FairseqMultiModel): + raise ValueError( + "SemisupervisedTranslationTask requires a FairseqMultiModel architecture" + ) + + # create SequenceGenerator for each model that has backtranslation dependency on it + self.sequence_generators = {} + if ( + self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None + ) and self.training: + for lang_pair in self.lang_pairs: + src, tgt = lang_pair.split("-") + key = "{}-{}".format(tgt, src) + self.sequence_generators[key] = SequenceGenerator( + [model.models[key]], + tgt_dict=self.dicts[src], + beam_size=args.bt_beam_size, + max_len_a=args.bt_max_len_a, + max_len_b=args.bt_max_len_b, + ) + decoder_lang_tok_idx = self.get_decoder_langtok(src) + + def backtranslate_fn( + sample, + model=model.models[key], + bos_token=decoder_lang_tok_idx, + sequence_generator=self.sequence_generators[key], + ): + return sequence_generator.generate( + [model], + sample, + bos_token=bos_token, + ) + + self.backtranslators[lang_pair] = backtranslate_fn + + return model + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + model.train() + + if update_num > 0: + self.update_step(update_num) + + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, {} + + def forward_backward(model, samples, logging_output_key, weight): + nonlocal agg_loss, agg_sample_size, agg_logging_output + if samples is None or len(samples) == 0: + return + loss, sample_size, logging_output = criterion(model, samples) + if ignore_grad: + loss *= 0 + else: + loss *= weight + optimizer.backward(loss) + agg_loss += loss.detach().item() + # TODO make summing of the sample sizes configurable + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[k] += logging_output[k] + agg_logging_output[logging_output_key] += logging_output[k] + + if self.lambda_parallel > 0.0: + for lang_pair in self.lang_pairs: + forward_backward( + model.models[lang_pair], + sample[lang_pair], + lang_pair, + self.lambda_parallel, + ) + + if self.lambda_otf_bt > 0.0: + for lang_pair in self.lang_pairs: + sample_key = _get_bt_dataset_key(lang_pair) + forward_backward( + model.models[lang_pair], + sample[sample_key], + sample_key, + self.lambda_otf_bt, + ) + + if self.lambda_denoising > 0.0: + for lang_pair in self.lang_pairs: + _, tgt = lang_pair.split("-") + sample_key = _get_denoising_dataset_key(lang_pair) + forward_backward( + model.models["{0}-{0}".format(tgt)], + sample[sample_key], + sample_key, + self.lambda_denoising, + ) + + return agg_loss, agg_sample_size, agg_logging_output + + def update_step(self, num_updates): + def lambda_step_func(config, n_iter): + """ + Update a lambda value according to its schedule configuration. 
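+ `config` is the list of (iteration, value) pairs produced by
+ parse_lambda_config; the value is linearly interpolated between the two
+ pairs that bracket n_iter, and held constant after the last pair.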
+ """ + ranges = [ + i + for i in range(len(config) - 1) + if config[i][0] <= n_iter < config[i + 1][0] + ] + if len(ranges) == 0: + assert n_iter >= config[-1][0] + return config[-1][1] + assert len(ranges) == 1 + i = ranges[0] + x_a, y_a = config[i] + x_b, y_b = config[i + 1] + return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a) + + if self.lambda_parallel_steps is not None: + self.lambda_parallel = lambda_step_func( + self.lambda_parallel_steps, num_updates + ) + if self.lambda_denoising_steps is not None: + self.lambda_denoising = lambda_step_func( + self.lambda_denoising_steps, num_updates + ) + if self.lambda_otf_bt_steps is not None: + self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates) diff --git a/fairseq/fairseq/tasks/sentence_prediction.py b/fairseq/fairseq/tasks/sentence_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..de80addaf20e902a04f251bb6d0e3712fc5439d9 --- /dev/null +++ b/fairseq/fairseq/tasks/sentence_prediction.py @@ -0,0 +1,303 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import contextlib +from dataclasses import dataclass, field +from typing import Optional +from omegaconf import MISSING, II, open_dict, OmegaConf + +import numpy as np +from fairseq.data import ( + ConcatSentencesDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + OffsetTokensDataset, + PrependTokenDataset, + RawLabelDataset, + RightPadDataset, + RightPaddingMaskDataset, + RollDataset, + SortDataset, + StripTokenDataset, + data_utils, +) +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import FairseqDataclass, FairseqTask, register_task +from fairseq.dataclass import ChoiceEnum + + +logger = logging.getLogger(__name__) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) + + +@dataclass +class SentencePredictionConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + num_classes: int = field( + default=-1, + metadata={"help": "number of classes or regression targets"}, + ) + init_token: Optional[int] = field( + default=None, + metadata={"help": "add token at the beginning of each batch item"}, + ) + separator_token: Optional[int] = field( + default=None, + metadata={"help": "add separator token between inputs"}, + ) + no_shuffle: bool = field( + default=False, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed tokens_per_sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + add_prev_output_tokens: bool = field( + default=False, + metadata={ + "help": "add prev_output_tokens to sample, used for encoder-decoder arch" + }, + ) + max_positions: int = field( + default=512, + metadata={"help": "max tokens per example"}, + ) + + regression_target: bool = II("criterion.regression_target") + classification_head_name: str = II("criterion.classification_head_name") + seed: int = II("common.seed") + + d2v2_multi: bool = field( + default=False, + metadata={"help": "prepare dataset for data2vec_multi"}, + ) + + +@register_task("sentence_prediction", 
dataclass=SentencePredictionConfig) +class SentencePredictionTask(FairseqTask): + """ + Sentence (or sentence pair) prediction (classification or regression) task. + + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + def __init__(self, cfg, data_dictionary, label_dictionary): + super().__init__(cfg) + self.dictionary = data_dictionary + self._label_dictionary = label_dictionary + + @classmethod + def load_dictionary(cls, filename): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + dictionary = Dictionary.load(filename) + dictionary.add_symbol("") + return dictionary + + @classmethod + def setup_task(cls, cfg, **kwargs): + assert cfg.num_classes > 0, "Must set task.num_classes" + + # load data dictionary + data_dict = cls.load_dictionary( + os.path.join(cfg.data, "input0", "dict.txt"), + ) + logger.info("[input] dictionary: {} types".format(len(data_dict))) + + # load label dictionary + if not cfg.regression_target: + label_dict = cls.load_dictionary( + os.path.join(cfg.data, "label", "dict.txt"), + ) + logger.info("[label] dictionary: {} types".format(len(label_dict))) + else: + label_dict = data_dict + return cls(cfg, data_dict, label_dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test).""" + + def get_path(key, split): + return os.path.join(self.cfg.data, key, split) + + def make_dataset(key, dictionary): + split_path = get_path(key, split) + + try: + dataset = data_utils.load_indexed_dataset( + split_path, + dictionary, + combine=combine, + ) + except Exception as e: + if "StorageException: [404] Path not found" in str(e): + logger.warning(f"dataset {e} not found") + dataset = None + else: + raise e + return dataset + + input0 = make_dataset("input0", self.source_dictionary) + assert input0 is not None, "could not find dataset: {}".format( + get_path("input0", split) + ) + input1 = make_dataset("input1", self.source_dictionary) + + if self.cfg.init_token is not None: + input0 = PrependTokenDataset(input0, self.cfg.init_token) + + if input1 is None: + src_tokens = input0 + else: + if self.cfg.separator_token is not None: + input1 = PrependTokenDataset(input1, self.cfg.separator_token) + + src_tokens = ConcatSentencesDataset(input0, input1) + + with data_utils.numpy_seed(self.cfg.seed): + shuffle = np.random.permutation(len(src_tokens)) + + src_tokens = maybe_shorten_dataset( + src_tokens, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.max_positions(), + self.cfg.seed, + ) + + if self.cfg.d2v2_multi: + net_input = { + "source": RightPadDataset( + src_tokens, + pad_idx=self.source_dictionary.pad(), + ), + "id": IdDataset(), + "padding_mask": RightPaddingMaskDataset(src_tokens), + } + else: + net_input = { + "src_tokens": RightPadDataset( + src_tokens, + pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": NumelDataset(src_tokens, reduce=False), + } + if self.cfg.add_prev_output_tokens: + prev_tokens_dataset = RightPadDataset( + RollDataset(src_tokens, 1), + pad_idx=self.dictionary.pad(), + ) + net_input.update( + prev_output_tokens=prev_tokens_dataset, + ) + + dataset = { + "id": IdDataset(), + "net_input": net_input, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens, reduce=True), + } + + if not self.cfg.regression_target: + label_dataset = make_dataset("label", self.label_dictionary) + if label_dataset is not None: + dataset.update( + target=OffsetTokensDataset( + 
StripTokenDataset( + label_dataset, + id_to_strip=self.label_dictionary.eos(), + ), + offset=-self.label_dictionary.nspecial, + ) + ) + else: + label_path = "{0}.label".format(get_path("label", split)) + if os.path.exists(label_path): + + def parse_regression_target(i, line): + values = line.split() + assert ( + len(values) == self.cfg.num_classes + ), f'expected num_classes={self.cfg.num_classes} regression target values on line {i}, found: "{line}"' + return [float(x) for x in values] + + with open(label_path) as h: + dataset.update( + target=RawLabelDataset( + [ + parse_regression_target(i, line.strip()) + for i, line in enumerate(h.readlines()) + ] + ) + ) + + nested_dataset = NestedDictionaryDataset( + dataset, + sizes=[src_tokens.sizes], + ) + + if self.cfg.no_shuffle: + dataset = nested_dataset + else: + dataset = SortDataset( + nested_dataset, + # shuffle + sort_order=[shuffle], + ) + + logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset))) + + self.datasets[split] = dataset + return self.datasets[split] + + def build_model(self, cfg, from_checkpoint=False): + from fairseq import models + + with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack(): + cfg.max_positions = self.cfg.max_positions + + model = models.build_model(cfg, self, from_checkpoint) + + model.register_classification_head( + self.cfg.classification_head_name, + num_classes=self.cfg.num_classes, + ) + + return model + + def max_positions(self): + return self.cfg.max_positions + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + @property + def label_dictionary(self): + return self._label_dictionary diff --git a/fairseq/fairseq/tasks/sentence_prediction_adapters.py b/fairseq/fairseq/tasks/sentence_prediction_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..afe556962621ba509d6a784d33c40e1c5406a6fb --- /dev/null +++ b/fairseq/fairseq/tasks/sentence_prediction_adapters.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
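The label pipeline above (a `StripTokenDataset` wrapped in an `OffsetTokensDataset`) is what turns binarized label symbols into the 0-based class ids expected by the classification head. Below is a minimal sketch of that arithmetic using plain-Python stand-ins rather than the fairseq dataset classes; the index values assume a default fairseq `Dictionary` with four special symbols and label symbols added in order.

```python
NSPECIAL = 4   # a default fairseq Dictionary reserves <s>, <pad>, </s>, <unk>
EOS = 2        # index of </s> in a default Dictionary

def strip_token(sample, id_to_strip):
    # StripTokenDataset: drop the eos appended during binarization
    return [t for t in sample if t != id_to_strip]

def offset_tokens(sample, offset):
    # OffsetTokensDataset: shift dictionary indices down to 0-based class ids
    return [t + offset for t in sample]

# The label line "1" is binarized as [index_of("1"), eos]; if label symbols were
# added in order "0", "1", ..., then "1" sits at index NSPECIAL + 1.
raw_label = [NSPECIAL + 1, EOS]
print(offset_tokens(strip_token(raw_label, EOS), offset=-NSPECIAL))  # -> [1]
```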
+ +import logging + +import contextlib +from omegaconf import open_dict, OmegaConf + +from fairseq.tasks import register_task +from fairseq.tasks.sentence_prediction import ( + SentencePredictionTask, + SentencePredictionConfig, +) + + +logger = logging.getLogger(__name__) + + +@register_task("sentence_prediction_adapters", dataclass=SentencePredictionConfig) +class SentencePredictionAdapterTask(SentencePredictionTask): + def build_model(self, cfg): + from fairseq import models + + with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack(): + cfg.max_positions = self.cfg.max_positions + + model = models.build_model(cfg, self) + + model.register_classification_head( + self.cfg.classification_head_name, + num_classes=self.cfg.num_classes, + ) + + logger.info("Freezing Embedding Parameters") + for parameter in model.encoder.sentence_encoder.embed_positions.parameters(): + parameter.requires_grad = False + for ( + parameter + ) in model.encoder.sentence_encoder.layernorm_embedding.parameters(): + parameter.requires_grad = False + for parameter in model.encoder.sentence_encoder.embed_tokens.parameters(): + parameter.requires_grad = False + + logger.info("Freezing Adapters") + for k, v in model.encoder.sentence_encoder.layers._modules.items(): + logger.info("Freezing Adapters in Layer " + str(k)) + if hasattr(v, "adapter_layer_norm"): + logger.info("Freezing Adapter LN") + for parameter in v.adapter_layer_norm.parameters(): + parameter.requires_grad = False + for parameter in v.adapter_modules.parameters(): + parameter.requires_grad = False + + return model diff --git a/fairseq/fairseq/tasks/sentence_ranking.py b/fairseq/fairseq/tasks/sentence_ranking.py new file mode 100644 index 0000000000000000000000000000000000000000..57f63aab6725922d1b07b5cf67c45a44356f454f --- /dev/null +++ b/fairseq/fairseq/tasks/sentence_ranking.py @@ -0,0 +1,219 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import numpy as np +from fairseq import utils +from fairseq.data import ( + ConcatSentencesDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + PrependTokenDataset, + RawLabelDataset, + RightPadDataset, + SortDataset, + TruncateDataset, + data_utils, +) +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("sentence_ranking") +class SentenceRankingTask(LegacyFairseqTask): + """ + Ranking task on multiple sentences. 
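+ Each example provides num_classes candidate sentences; every candidate is
+ concatenated with the shared input0 context, scored with a single-logit
+ classification head, and the candidates are ranked by that score.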
+ + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", metavar="FILE", help="file prefix for data") + parser.add_argument( + "--num-classes", type=int, help="number of sentences to be ranked" + ) + parser.add_argument( + "--init-token", + type=int, + help="add token at the beginning of each batch item", + ) + parser.add_argument( + "--separator-token", type=int, help="add separator token between inputs" + ) + parser.add_argument("--no-shuffle", action="store_true") + parser.add_argument( + "--shorten-method", + default="none", + choices=["none", "truncate", "random_crop"], + help="if not none, shorten sequences that exceed --tokens-per-sample", + ) + parser.add_argument( + "--shorten-data-split-list", + default="", + help="comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)', + ) + parser.add_argument( + "--max-option-length", type=int, help="max length for each option" + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + + @classmethod + def load_dictionary(cls, args, filename, source=True): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + dictionary = Dictionary.load(filename) + dictionary.add_symbol("") + return dictionary + + @classmethod + def setup_task(cls, args, **kwargs): + assert ( + args.criterion == "sentence_ranking" + ), "Must set --criterion=sentence_ranking" + + # load data dictionary + data_dict = cls.load_dictionary( + args, + os.path.join(args.data, "input0", "dict.txt"), + source=True, + ) + logger.info("[input] dictionary: {} types".format(len(data_dict))) + return SentenceRankingTask(args, data_dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test).""" + + def get_path(type, split): + return os.path.join(self.args.data, type, split) + + def make_dataset(type, dictionary): + split_path = get_path(type, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + self.args.dataset_impl, + combine=combine, + ) + return dataset + + input0 = make_dataset("input0", self.source_dictionary) + input_options = [ + make_dataset("input{idx}".format(idx=idx + 1), self.source_dictionary) + for idx in range(self.args.num_classes) + ] + + if self.args.separator_token is not None: + input0 = PrependTokenDataset(input0, self.args.separator_token) + + src_tokens = [] + for input_option in input_options: + if self.args.init_token is not None: + input_option = PrependTokenDataset(input_option, self.args.init_token) + if self.args.max_option_length is not None: + input_option = TruncateDataset( + input_option, self.args.max_option_length + ) + src_token = ConcatSentencesDataset(input_option, input0) + src_token = maybe_shorten_dataset( + src_token, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + self.args.max_positions, + self.args.seed, + ) + src_tokens.append(src_token) + + with data_utils.numpy_seed(self.args.seed): + shuffle = np.random.permutation(len(src_tokens[0])) + + dataset = { + "id": IdDataset(), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens[0], reduce=True), + } + + for src_token_idx in range(len(src_tokens)): + dataset.update( + { + "net_input{idx}".format(idx=src_token_idx + 1): { + "src_tokens": 
RightPadDataset( + src_tokens[src_token_idx], + pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": NumelDataset( + src_tokens[src_token_idx], reduce=False + ), + } + } + ) + + label_path = "{}.label".format(get_path("label", split)) + if os.path.exists(label_path): + with open(label_path) as h: + dataset.update( + target=RawLabelDataset([int(x.strip()) for x in h.readlines()]) + ) + + nested_dataset = NestedDictionaryDataset( + dataset, + sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])], + ) + + if self.args.no_shuffle: + dataset = nested_dataset + else: + dataset = SortDataset( + nested_dataset, + # shuffle + sort_order=[shuffle], + ) + + logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset))) + + self.datasets[split] = dataset + return self.datasets[split] + + def build_model(self, args, from_checkpoint=False): + from fairseq import models + + model = models.build_model(args, self, from_checkpoint) + + model.register_classification_head( + getattr(args, "ranking_head_name", "sentence_classification_head"), + num_classes=1, + ) + + return model + + def max_positions(self): + return self.args.max_positions + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary diff --git a/fairseq/fairseq/tasks/simultaneous_translation.py b/fairseq/fairseq/tasks/simultaneous_translation.py new file mode 100644 index 0000000000000000000000000000000000000000..9576b26801dcc9c1433b0e8632926117c0d50aea --- /dev/null +++ b/fairseq/fairseq/tasks/simultaneous_translation.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask +from fairseq.tasks.translation import TranslationTask, TranslationConfig + +try: + import examples.simultaneous_translation # noqa + + import_successful = True +except BaseException: + import_successful = False + + +logger = logging.getLogger(__name__) + + +def check_import(flag): + if not flag: + raise ImportError( + "'examples.simultaneous_translation' is not correctly imported. " + "Please considering `pip install -e $FAIRSEQ_DIR`." + ) + + +@register_task("simul_speech_to_text") +class SimulSpeechToTextTask(SpeechToTextTask): + def __init__(self, args, tgt_dict): + check_import(import_successful) + super().__init__(args, tgt_dict) + + +@register_task("simul_text_to_text", dataclass=TranslationConfig) +class SimulTextToTextTask(TranslationTask): + def __init__(self, cfg, src_dict, tgt_dict): + check_import(import_successful) + super().__init__(cfg, src_dict, tgt_dict) diff --git a/fairseq/fairseq/tasks/span_masked_lm.py b/fairseq/fairseq/tasks/span_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..d746aa154c815b339451e04131642dc9d419bb2a --- /dev/null +++ b/fairseq/fairseq/tasks/span_masked_lm.py @@ -0,0 +1,243 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
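SentenceRankingTask.load_dataset above emits one `net_input{i}` sub-batch per candidate instead of a single `net_input`. The sketch below shows how such a sample could be consumed; the scorer is a hypothetical stand-in for a model with the ranking head, not the fairseq `sentence_ranking` criterion itself.

```python
import torch

def rank(sample, score_option):
    """Score each net_input{i} independently and pick the best candidate per row."""
    scores, idx = [], 1
    while f"net_input{idx}" in sample:
        scores.append(score_option(sample[f"net_input{idx}"]["src_tokens"]))  # (batch,)
        idx += 1
    return torch.stack(scores, dim=1).argmax(dim=1)

# Toy usage: a dummy scorer and two candidates for a batch of three examples.
dummy = lambda tokens: tokens.float().sum(dim=1)
sample = {
    "net_input1": {"src_tokens": torch.tensor([[1, 2], [3, 4], [5, 6]])},
    "net_input2": {"src_tokens": torch.tensor([[9, 9], [0, 1], [7, 7]])},
}
print(rank(sample, dummy))  # tensor([1, 0, 1])
```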
+ +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from omegaconf import II, MISSING + +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + StripTokenDataset, + TokenBlockDataset, + data_utils, +) +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +from ..data.indexed_dataset import get_available_dataset_impl + +logger = logging.getLogger(__name__) + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) + + +@dataclass +class SpanMaskedLMConfig(FairseqDataclass): + shuffle: bool = field( + default=False, + ) + noise_density: float = field( + default=0.15, + metadata={"help": "What fraction of the tokens to select as noise"}, + ) + mean_noise_span_length: float = field( + default=3, + metadata={"help": "Mean noise span length, must be >= 1"}, + ) + data: str = field( + default=MISSING, + metadata={ + "help": "colon separated path to data directories list, " + "will be iterated upon during epochs in round-robin manner" + }, + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + seed: int = II("common.seed") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + max_source_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + include_target_tokens: bool = field( + default=False, + metadata={ + "help": "include target tokens in model input. this is used for data2vec" + }, + ) + + +@register_task("span_masked_lm", dataclass=SpanMaskedLMConfig) +class SpanMaskedLMTask(FairseqTask): + """ + Span masked language modeling task. (ie. 
T5) + """ + + cfg: SpanMaskedLMConfig + + def __init__(self, cfg, dictionary): + super().__init__(cfg) + self.dictionary = dictionary + + @classmethod + def setup_task(cls, cfg: SpanMaskedLMConfig, **kwargs): + """Setup the task.""" + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + if not hasattr(cfg, "shuffle"): + cfg.shuffle = False + return cls(cfg, dictionary) + + def _load_dataset_split(self, split, epoch, combine): + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.dictionary, + self.cfg.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = StripTokenDataset(dataset, self.dictionary.eos()) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.cfg.tokens_per_sample, + self.cfg.seed, + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 2, # one less for and one for + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + dataset = AppendTokenDataset(dataset, self.source_dictionary.eos()) + return dataset + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + dataset = self._load_dataset_split(split, epoch, combine) + + self.datasets[split] = SpanMaskedTokensDataset( + dataset, + self.dictionary, + noise_density=self.cfg.noise_density, + mean_noise_span_length=self.cfg.mean_noise_span_length, + shuffle=self.cfg.shuffle, + seed=self.cfg.seed, + ) + logger.info( + "Split: {0}, Loaded {1} samples of span_masked_tokens_dataset".format( + split, + len(self.datasets[split]), + ) + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We assume that the input begins with a + bos symbol (``) and ends with an eos symbol (``). 
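+ The returned batch mirrors training samples: `src_tokens`, `src_lengths` and
+ `prev_output_tokens` form the net_input, and the token block itself is the target.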
+ """ + pad = self.source_dictionary.pad() + eos = self.source_dictionary.eos() + src_dataset = TokenBlockDataset( + src_tokens, + src_lengths, + block_size=self.cfg.tokens_per_sample - 2, # for and + pad=pad, + eos=eos, + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + prev_output_tokens = PrependTokenDataset( + StripTokenDataset(src_dataset, eos), eos + ) + src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + "prev_output_tokens": PadDataset( + prev_output_tokens, pad_idx=pad, left_pad=False + ), + }, + "target": src_dataset, + }, + sizes=[np.array(src_lengths)], + ) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.cfg.max_source_positions, self.cfg.max_target_positions) + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.dictionary + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary`.""" + return self.dictionary diff --git a/fairseq/fairseq/tasks/speech_dlm_task.py b/fairseq/fairseq/tasks/speech_dlm_task.py new file mode 100644 index 0000000000000000000000000000000000000000..340732b928122356ab2183d050b504da1773e91a --- /dev/null +++ b/fairseq/fairseq/tasks/speech_dlm_task.py @@ -0,0 +1,561 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional +from collections import OrderedDict + +import numpy as np +import torch +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + Dictionary, + IdDataset, + LMContextWindowDataset, + MonolingualDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + SpeechDLMDataset, + StripTokenDataset, + TokenBlockDataset, + TruncatedDictionary, + data_utils, +) +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import LegacyFairseqTask, register_task +from omegaconf import II + + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) +logger = logging.getLogger(__name__) + + +@dataclass +class SpeechDLMConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + channels: Optional[str] = field( + default=None, + metadata={ + "help": 'comma-separated list of channels to load e.g., "unitA,unitB"' + "(default: load all possible channels in the data path)" + }, + ) + channel_weights: Optional[str] = field( + default=None, + metadata={ + "help": "comma-separated list of weights for different losses" + "(default: None, which means all losses are treated equally)" + }, + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. 
" + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + output_dictionary_size: int = field( + default=-1, metadata={"help": "limit the size of output dictionary"} + ) + # str type is a workaround to put **default=True** here + next_unit_prediction: str = field( + default="False", + metadata={ + "help": "Perform Next Unit Prediction, expected str input ('True' or 'False')" + }, + ) + edge_unit_prediction: str = field( + default="True", + metadata={ + "help": "Perform Edge Unit Prediction, expected str input ('True' or 'False')" + }, + ) + duration_prediction: str = field( + default="True", + metadata={ + "help": "Perform Duration Prediction, expected str input ('True' or 'False')" + }, + ) + delayed_duration_target: str = field( + default="True", + metadata={ + "help": "Perform Delayed Duration Prediction, expected str input ('True' or 'False')" + "(default: 'True')" + }, + ) + max_target_durations: Optional[int] = field( + default=256, + metadata={"help": "max duration considered (cut off to this value)"}, + ) + add_bos_token: bool = field( + default=False, metadata={"help": "prepend beginning of sentence token ()"} + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + # TODO common vars below add to parent + seed: int = II("common.seed") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + + +@register_task("speech_dlm_task", dataclass=SpeechDLMConfig) +class SpeechDLMTask(LegacyFairseqTask): + """Task for the SpeechDLM model as described in the paper: + https://arxiv.org/pdf/2203.16502.pdf + + It create a multi-channel dataset (SpeechDLMDataset) from multiple + dictionaries. + + Args: + dictionaries (Dict[str, ~fairseq.data.Dictionary]): the dictionaries for + each input channel of the SpeechDLM model + output_dictionaries (Dict[str, ~fairseq.data.Dictionary]): the dictionaries + for the output of each channel of the SpeechDLM model. In most cases it + will be the same as *dictionaries*. + targets (List[str]): list of the target types that the SpeechDLM model + should predict. Can be one of "next", "edge", "duration". + Defaults to "next". + + .. note:: + + The SpeechDLM task is only compatible with + :mod:`fairseq-train` and :mod:`fairseq-validate`. + To generate new samples, please refer to example codes + at examples/textless_nlp/dgslm . 
+ """ + + def __init__(self, args, dicts, output_dicts=None, targets=None): + super().__init__(args) + self.dicts = dicts + self.output_dicts = output_dicts or dicts + + if targets is None: + targets = ["next"] + self.targets = targets + + self.channels = list(dicts.keys()) + + if args.channel_weights is not None: + self.channel_weights = [float(w) for w in args.channel_weights.split(",")] + else: + self.channel_weights = [1.0 for _ in self.channels] + assert len(self.channel_weights) == len( + self.channels + ), "number of channel_weights must be the same as number of channels" + + assert str(args.next_unit_prediction).lower() in [ + "true", + "false", + ], f"Expected to be a string of boolean, found {args.next_unit_prediction}" + assert str(args.edge_unit_prediction).lower() in [ + "true", + "false", + ], f"Expected to be a string of boolean, found {args.edge_unit_prediction}" + assert str(args.duration_prediction).lower() in [ + "true", + "false", + ], f"Expected to be a string of boolean, found {args.duration_prediction}" + assert str(args.delayed_duration_target).lower() in [ + "true", + "false", + ], f"Expected to be a string of boolean, found {args.delayed_duration_target}" + self.next_unit_prediction = bool( + str(args.next_unit_prediction).lower() == "true" + ) + self.edge_unit_prediction = bool( + str(args.edge_unit_prediction).lower() == "true" + ) + self.duration_prediction = bool(str(args.duration_prediction).lower() == "true") + self.delayed_duration_target = bool( + str(args.delayed_duration_target).lower() == "true" + ) + + self.max_target_durations = args.max_target_durations + + @classmethod + def setup_dictionary(cls, args, **kwargs): + """The dictionaries will be a dict over channel keys and values of type + ~fairseq.data.Dictionary. + """ + paths = utils.split_paths(args.data) + assert len(paths) > 0 + data_path = paths[0] + + dicts = None + output_dicts = None + if args.channels is None: + sorted_channels = sorted( + name[5:-4] + for name in os.listdir(data_path) + if name[:5] == "dict." and name[-4:] == ".txt" + ) + else: + sorted_channels = sorted(args.channels.split(",")) + logger.info("channels: {}".format(sorted_channels)) + # load dictionaries + dicts = OrderedDict() + output_dicts = OrderedDict() + for channel in sorted_channels: + dictionary = Dictionary.load( + os.path.join(data_path, "dict.{}.txt".format(channel)) + ) + logger.info("[{}] dictionary: {} types".format(channel, len(dictionary))) + output_dictionary = dictionary + if args.output_dictionary_size >= 0: + output_dictionary = TruncatedDictionary( + dictionary, args.output_dictionary_size + ) + dicts[channel] = dictionary + output_dicts[channel] = output_dictionary + if len(dicts) > 0: + assert dicts[channel].pad() == dicts[sorted_channels[0]].pad() + assert dicts[channel].bos() == dicts[sorted_channels[0]].bos() + assert dicts[channel].eos() == dicts[sorted_channels[0]].eos() + assert dicts[channel].unk() == dicts[sorted_channels[0]].unk() + return (dicts, output_dicts) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + """ + dicts, output_dicts = cls.setup_dictionary(args, **kwargs) + + targets = [] + if str(getattr(args, "next_unit_prediction", "false")).lower() == "true": + targets.append("next") + if str(getattr(args, "edge_unit_prediction", "false")).lower() == "true": + targets.append("edge") + if str(getattr(args, "duration_prediction", "false")).lower() == "true": + targets.append("duration") + if len(targets) == 0: + # standard language modeling + targets = ["next"] + + return cls(args, dicts, output_dicts, targets=targets) + + def build_model(self, args): + model = super().build_model(args) + for target in self.targets: + if target not in model.supported_targets: + raise ValueError("Unsupported SpeechDLM target: {}".format(target)) + return model + + def load_dataset( + self, split: str, epoch=1, combine=False, **kwargs + ) -> SpeechDLMDataset: + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + + data_path = paths[(epoch - 1) % len(paths)] + + channel_datasets = {} + for channel in self.channels: + split_path = os.path.join(data_path, split + "." + channel) + dictionary = self.dicts[channel] + output_dictionary = self.output_dicts[channel] + + dataset = data_utils.load_indexed_dataset( + split_path, dictionary, self.args.dataset_impl, combine=combine + ) + + if dataset is None: + raise FileNotFoundError( + "[{}] Dataset not found: {} ({})".format(channel, split, split_path) + ) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + self.args.tokens_per_sample, + self.args.seed, + ) + + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.args.tokens_per_sample, + pad=dictionary.pad(), + eos=dictionary.eos(), + break_mode=self.args.sample_break_mode, + include_targets=True, + ) + + add_eos_for_other_targets = ( + self.args.sample_break_mode is not None + and self.args.sample_break_mode != "none" + ) + + channel_datasets[channel] = MonolingualDataset( + dataset=dataset, + sizes=dataset.sizes, + src_vocab=dictionary, + tgt_vocab=output_dictionary, + add_eos_for_other_targets=add_eos_for_other_targets, + shuffle=False, + targets=["future"], + add_bos_token=self.args.add_bos_token, + ) + + self.datasets[split] = SpeechDLMDataset( + datasets=channel_datasets, + targets=self.targets, + max_target_durations=self.max_target_durations, + shuffle=True, + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We prepend an eos token to src_tokens + (or bos if `--add-bos-token` is set) and we append a to target. + This is convenient both for generation with a prefix and LM scoring. 
+ """ + src_datasets = {} + tgt_datasets = {} + for channel in src_tokens[0]: + dataset = StripTokenDataset( + TokenBlockDataset( + [src_tokens[i][channel] for i in range(len(src_tokens))], + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionaries[channel].pad(), + eos=self.source_dictionaries[channel].eos(), + break_mode="eos", + ), + # remove eos from (end of) target sequence + self.source_dictionaries[channel].eos(), + ) + src_dataset = PrependTokenDataset( + dataset, + token=( + self.source_dictionaries[channel].bos() + if getattr(self.args, "add_bos_token", False) + else self.source_dictionaries[channel].eos() + ), + ) + tgt_dataset = AppendTokenDataset( + dataset, token=self.source_dictionaries[channel].pad() + ) + + src_datasets[channel] = src_dataset + tgt_datasets[channel] = tgt_dataset + + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": OrderedDict( + [ + ( + channel, + PadDataset( + src_datasets[channel], + pad_idx=self.source_dictionaries[channel].pad(), + left_pad=False, + ), + ) + for channel in src_datasets + ] + ), + "src_lengths": NumelDataset( + next(iter(src_datasets.values())), reduce=False + ), + }, + "target": OrderedDict( + [ + ( + channel, + PadDataset( + tgt_datasets[channel], + pad_idx=self.source_dictionaries[channel].pad(), + left_pad=False, + ), + ) + for channel in tgt_datasets + ] + ), + }, + sizes=[np.array(src_lengths)], + ) + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + # Generation will always be conditioned on bos_token + if getattr(self.args, "add_bos_token", False): + bos_token = self.source_dictionary.bos() + else: + bos_token = self.source_dictionary.eos() + + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the SpeechDLM task is not supported" + ) + # SequenceGenerator doesn't use src_tokens directly, we need to + # pass the `prefix_tokens` argument instead + if prefix_tokens is None: + prefix_tokens = {} + for channel in sample["net_input"]["src_tokens"]: + if sample["net_input"]["src_tokens"][channel].nelement(): + prefix_tokens_channel = sample["net_input"]["src_tokens"][ + channel + ] + if prefix_tokens_channel[:, 0].eq(bos_token).all(): + prefix_tokens_channel = prefix_tokens_channel[:, 1:] + prefix_tokens[channel] = prefix_tokens_channel + else: + prefix_tokens = None + break + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + if context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=self.args.tokens_per_sample, + context_window=context_window, + pad_idx=self.source_dictionary.pad(), + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + + @property + def source_dictionary(self): + """Return the 
:class:`~fairseq.data.Dictionary` for the language + model.""" + return self.dicts[self.channels[0]] + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.output_dicts[self.channels[0]] + + @property + def source_dictionaries(self): + """Return the dict of :class:`~fairseq.data.Dictionary` for the + multichannel language model.""" + return self.dicts + + @property + def target_dictionaries(self): + """Return the dict of :class:`~fairseq.data.Dictionary` for the + multichannel language model.""" + return self.output_dicts + + def build_generator(self, models, args, extra_gen_cls_kwargs=None): + + from fairseq.models.speech_dlm.sequence_generator import ( + multichannel_search, + MultichannelSequenceGenerator, + ) + + # Choose search strategy. Defaults to Beam Search. + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + assert ( + sampling_topk < 0 or sampling + ), "--sampling-topk requires sampling (not beam search)" + assert ( + sampling_topp < 0 or sampling + ), "--sampling-topp requires sampling (not beam search)" + + if sampling: + search_strategy = multichannel_search.ContiguousMultichannelSampling( + self.target_dictionaries, sampling_topk, sampling_topp + ) + else: + search_strategy = multichannel_search.ContiguousMultichannelBeamSearch( + self.target_dictionaries + ) + + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + + return MultichannelSequenceGenerator( + models, + self.target_dictionaries, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 500), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search_strategy, + duration_temperature=getattr(args, "duration_temperature", 1.0), + **extra_gen_cls_kwargs, + ) diff --git a/fairseq/fairseq/tasks/speech_to_speech.py b/fairseq/fairseq/tasks/speech_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..5aaaa95a90cce0ae3d4bc8cf9b79b312dd342b3f --- /dev/null +++ b/fairseq/fairseq/tasks/speech_to_speech.py @@ -0,0 +1,597 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
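A minimal sketch of the stacked-unit packing that StackUnitSequenceGenerator.pack_units implements further down in this file: each group of n_frames_per_step predicted units is folded into a single token id by treating the units as digits in base vocab_size and then re-adding the special-symbol offset. The vocabulary size, offset and unit values below are assumed example values, not taken from this diff.

import torch

vocab_size, offset, n_frames_per_step = 100, 4, 2  # assumed example values
units = torch.tensor([[[9, 11]]])  # (bsz=1, tgt_len=1, n_frames_per_step=2)
scale = torch.tensor(
    [vocab_size ** (n_frames_per_step - 1 - i) for i in range(n_frames_per_step)]
)
mask = units >= offset  # ids below the offset are special symbols and are ignored
packed = ((units - offset) * scale * mask).sum(dim=2) + offset
print(packed)  # tensor([[511]]) == (9 - 4) * 100 + (11 - 4) + 4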
+ +import json +import logging +import math +from argparse import Namespace +from pathlib import Path +from typing import List + +import torch +import torch.nn as nn + +from fairseq import utils +from fairseq.data import Dictionary +from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig +from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDataset, + TextTargetMultitaskData, +) +from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.tasks.speech_to_text import DummyMultiTask +from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion + +logger = logging.getLogger(__name__) + + +class StackUnitSequenceGenerator(nn.Module): + def __init__(self, tgt_dict, vocab_size): + super().__init__() + self.pad = tgt_dict.pad() + self.eos = tgt_dict.eos() + self.unk = tgt_dict.unk() + self.offset = len(tgt_dict) - vocab_size + self.vocab_size = vocab_size + + def pack_units(self, input: torch.Tensor, n_frames_per_step) -> torch.Tensor: + if n_frames_per_step <= 1: + return input + + bsz, _, n = input.shape + assert n == n_frames_per_step + + scale = [ + pow(self.vocab_size, n_frames_per_step - 1 - i) + for i in range(n_frames_per_step) + ] + scale = torch.LongTensor(scale).squeeze(0).to(input.device) + mask = input >= self.offset + res = ((input - self.offset) * scale * mask).sum(dim=2) + self.offset + return res + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + # currently only support viterbi search for stacked units + model = models[0] + model.eval() + + max_len = model.max_decoder_positions() + # TODO: incorporate max_len_a and max_len_b + + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len, _ = src_tokens.size() + n_frames_per_step = model.decoder.n_frames_per_step + + # initialize + encoder_out = model.forward_encoder( + src_tokens, src_lengths, speaker=sample["speaker"] + ) + incremental_state = {} + pred_out, attn, scores = [], [], [] + finished = src_tokens.new_zeros((bsz,)).bool() + + prev_output_tokens = src_lengths.new_zeros((bsz, 1)).long().fill_(self.eos) + for _ in range(max_len): + cur_out, cur_extra = model.forward_decoder( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + ) + + lprobs = model.get_normalized_probs([cur_out], log_probs=True) + # never select pad, unk + lprobs[:, :, self.pad] = -math.inf + lprobs[:, :, self.unk] = -math.inf + + cur_pred_lprob, cur_pred_out = torch.max(lprobs, dim=2) + scores.append(cur_pred_lprob) + pred_out.append(cur_pred_out) + + prev_output_tokens = torch.cat( + ( + prev_output_tokens, + self.pack_units( + cur_pred_out.view(bsz, 1, n_frames_per_step), n_frames_per_step + ), + ), + dim=1, + ) + + attn.append(cur_extra["attn"][0]) + + cur_finished = torch.any(cur_pred_out.squeeze(1) == self.eos, dim=1) + finished = finished | cur_finished + if finished.sum().item() == bsz: + break + + pred_out = torch.cat(pred_out, dim=1).view(bsz, -1) + attn = torch.cat(attn, dim=2) + alignment = attn.max(dim=1)[1] + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + scores = torch.cat(scores, dim=1) + eos_idx = (pred_out == self.eos).nonzero(as_tuple=True) + out_lens = src_lengths.new_zeros((bsz,)).long().fill_(max_len) + for b, l in zip(eos_idx[0], eos_idx[1]): + out_lens[b] = min(l, out_lens[b]) + + hypos = [ + [ + { + 
"tokens": pred_out[b, :out_len], + "attn": attn[b, :, :out_len], + "alignment": alignment[b, :out_len], + "positional_scores": scores[b, :out_len], + "score": utils.item(scores[b, :out_len].sum().data), + } + ] + for b, out_len in zip(range(bsz), out_lens) + ] + + return hypos + + +@register_task("speech_to_speech") +class SpeechToSpeechTask(LegacyFairseqTask): + @classmethod + def add_args(cls, parser): + parser.add_argument("data", help="manifest root path") + parser.add_argument( + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--multitask-config-yaml", + type=str, + default=None, + help="Configuration YAML filename for the multitasks (under manifest root)", + ) + parser.add_argument( + "--max-source-positions", + default=6000, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + parser.add_argument( + "--target-is-code", + action="store_true", + help="set if target is discrete unit instead of spectrogram", + ) + parser.add_argument( + "--target-code-size", type=int, default=None, help="# discrete units" + ) + parser.add_argument( + "--n-frames-per-step", + type=int, + default=1, + help="# stacked frames, use 0 for reduced discrete unit sequence", + ) + parser.add_argument("--eval-inference", action="store_true") + parser.add_argument( + "--eval-args", + type=str, + default="{}", + help='generation args for speech-to-unit model , e.g., \'{"beam": 5, "max_len_a": 1}\', as JSON string', + ) + parser.add_argument("--eos-prob-threshold", type=float, default=0.5) + parser.add_argument( + "--mcd-normalize-type", + type=str, + default="targ", + choices=["targ", "pred", "path"], + ) + parser.add_argument( + "--vocoder", + type=str, + default="griffin_lim", + choices=["griffin_lim", "hifigan", "code_hifigan"], + ) + parser.add_argument("--spec-bwd-max-iter", type=int, default=8) + parser.add_argument( + "--infer-target-lang", + type=str, + default="", + help="target language for inference", + ) + + def __init__(self, args, tgt_dict, infer_tgt_lang_id=None): + super().__init__(args) + self.tgt_dict = tgt_dict + self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml) + + self.multitask_tasks = {} + self.tgt_dict_mt = None + self.eos_token_mt = None + if getattr(args, "multitask_config_yaml", None) is not None: + multitask_cfg = MultitaskConfig( + Path(args.data) / args.multitask_config_yaml + ) + first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index + for i, (task_name, task_config) in enumerate( + multitask_cfg.get_all_tasks().items() + ): + task_obj = DummyMultiTask( + task_config, + task_config.tgt_dict, + first_pass=i == first_pass_task_idx, + ) + self.multitask_tasks[task_name] = task_obj + if task_obj.is_first_pass_decoder: + self.tgt_dict_mt = task_obj.target_dictionary + if task_config.prepend_bos_and_append_tgt_lang_tag: + self.eos_token_mt = task_config.eos_token + assert not isinstance(self.eos_token_mt, List) + + if not self.eos_token_mt: + raise Warning( + "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator" + ) + + self._infer_tgt_lang_id = infer_tgt_lang_id + + @classmethod + def setup_task(cls, args, **kwargs): + data_cfg = data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml) + tgt_dict = None + infer_tgt_lang_id = None + if 
args.target_is_code: + if data_cfg.prepend_tgt_lang_tag_as_bos: + # dictionary with language tags + dict_path = Path(args.data) / data_cfg.vocab_filename + if not dict_path.is_file(): + raise FileNotFoundError( + f"Dict has to be provided when setting prepend_tgt_lang_tag_as_bos: true, but dict not found: {dict_path}" + ) + tgt_dict = Dictionary.load(dict_path.as_posix()) + + # target langauge for inference + if args.infer_target_lang != "": + tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format( + args.infer_target_lang + ) + infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag) + assert infer_tgt_lang_id != tgt_dict.unk() + else: + assert args.target_code_size is not None + + tgt_dict = Dictionary() + for i in range(args.target_code_size): + tgt_dict.add_symbol(str(i)) + logger.info(f"dictionary size: " f"{len(tgt_dict):,}") + + if getattr(args, "train_subset", None) is not None: + if not all(s.startswith("train") for s in args.train_subset.split(",")): + raise ValueError('Train splits should be named like "train*".') + + assert args.n_frames_per_step >= 1 + assert ( + not args.eval_inference + or (args.target_is_code and args.vocoder == "code_hifigan") + or (not args.target_is_code and args.vocoder != "code_hifigan") + ) + + return cls(args, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id) + + def build_criterion(self, args): + from fairseq import criterions + + if len(self.multitask_tasks) > 0: + if self.args.target_is_code and not args._name.startswith("speech_to_unit"): + raise ValueError( + "set --criterion speech_to_unit for speech-to-unit loss with multitask" + ) + elif not self.args.target_is_code and not args._name.startswith( + "speech_to_spectrogram" + ): + raise ValueError( + "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask" + ) + + return criterions.build_criterion(args, self) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + self.datasets[split] = SpeechToSpeechDatasetCreator.from_tsv( + root=self.args.data, + data_cfg=self.data_cfg, + splits=split, + is_train_split=split.startswith("train"), + epoch=epoch, + seed=self.args.seed, + target_is_code=self.args.target_is_code, + tgt_dict=self.target_dictionary, + n_frames_per_step=self.args.n_frames_per_step, + multitask=self.multitask_tasks, + ) + + @property + def target_dictionary(self): + return self.tgt_dict + + @property + def target_dictionary_mt(self): + return self.tgt_dict_mt + + @property + def source_dictionary(self): + return None + + def max_positions(self): + return self.args.max_source_positions, self.args.max_target_positions + + def build_model(self, args, from_checkpoint=False): + args.input_feat_per_channel = self.data_cfg.input_feat_per_channel + args.input_channels = self.data_cfg.input_transformed_channels + args.target_speaker_embed = self.data_cfg.target_speaker_embed is not None + args.n_frames_per_step = self.args.n_frames_per_step + + model = super().build_model(args, from_checkpoint) + + if len(self.multitask_tasks) > 0: + from fairseq.models.speech_to_speech.s2s_transformer import ( + S2STransformerMultitaskModelBase, + ) + + assert isinstance(model, S2STransformerMultitaskModelBase) + + if self.args.eval_inference: + self.eval_gen_args = json.loads(self.args.eval_args) + self.generator = self.build_generator( + [model], Namespace(**self.eval_gen_args) + ) + + return model + + def build_generator_dual_decoder( + self, + models, + args, + extra_gen_cls_kwargs=None, + ): + from examples.speech_to_speech.unity.sequence_generator_multi_decoder 
import ( + MultiDecoderSequenceGenerator, + ) + + return MultiDecoderSequenceGenerator( + models, + self.target_dictionary, + self.target_dictionary_mt, + beam_size=max(1, getattr(args, "beam", 1)), + beam_size_mt=max(1, getattr(args, "beam_mt", 1)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + max_len_a_mt=getattr(args, "max_len_a_mt", 0), + max_len_b_mt=getattr(args, "max_len_b_mt", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + **extra_gen_cls_kwargs, + ) + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + ): + + if not self.args.target_is_code or self.args.eval_inference: + from fairseq.models.text_to_speech.vocoder import get_vocoder + + self.vocoder = get_vocoder(self.args, self.data_cfg) + self.vocoder = ( + self.vocoder.cuda() + if torch.cuda.is_available() and not self.args.cpu + else self.vocoder.cpu() + ) + + has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None + + if self.args.target_is_code: + if self.args.n_frames_per_step == 1: + if has_dual_decoder: + seq_generator = self.build_generator_dual_decoder( + models, + args, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + else: + seq_generator = super().build_generator( + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + else: + assert ( + getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1 + ), "only support viterbi search for stacked units" + seq_generator = StackUnitSequenceGenerator( + self.tgt_dict, + self.args.target_code_size, + ) + else: + if has_dual_decoder: + if getattr(args, "teacher_forcing", False): + raise NotImplementedError + else: + from fairseq.speech_generator import MultiDecoderSpeechGenerator + + generator = MultiDecoderSpeechGenerator + + lang_token_ids_aux = { + i + for s, i in self.tgt_dict_mt.indices.items() + if TextTargetMultitaskData.is_lang_tag(s) + } + + if extra_gen_cls_kwargs is None: + extra_gen_cls_kwargs = {} + extra_gen_cls_kwargs[ + "symbols_to_strip_from_output" + ] = lang_token_ids_aux + + eos_id_mt = ( + self.tgt_dict_mt.index(self.eos_token_mt) + if self.eos_token_mt + else None + ) + assert eos_id_mt != self.tgt_dict_mt.unk() + extra_gen_cls_kwargs["eos_mt"] = eos_id_mt + + seq_generator = generator( + models, + args, + self.vocoder, + self.data_cfg, + self.target_dictionary_mt, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold, + **extra_gen_cls_kwargs, + ) + else: + if getattr(args, "teacher_forcing", False): + from fairseq.speech_generator import ( + TeacherForcingAutoRegressiveSpeechGenerator, + ) + + generator = TeacherForcingAutoRegressiveSpeechGenerator + logger.info("Teacher forcing mode for generation") + else: + from fairseq.speech_generator import AutoRegressiveSpeechGenerator + + generator = AutoRegressiveSpeechGenerator + + seq_generator = generator( + models[0], + self.vocoder, + self.data_cfg, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold, + ) + + return seq_generator + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + for task_name, task_obj in 
self.multitask_tasks.items(): + criterion.set_multitask_loss_weight( + task_name, task_obj.args.get_loss_weight(update_num) + ) + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].train() + + loss, sample_size, logging_output = super().train_step( + sample, model, criterion, optimizer, update_num, ignore_grad + ) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + for task_name in self.multitask_tasks.keys(): + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].eval() + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + if self.args.eval_inference: + hypos, inference_losses = self.valid_step_with_inference( + sample, model, self.generator + ) + for k, v in inference_losses.items(): + assert k not in logging_output + logging_output[k] = v + + return loss, sample_size, logging_output + + def valid_step_with_inference(self, sample, model, generator): + if self.args.target_is_code: + hypos = generator.generate([model], sample) + tgt_lens = ( + sample["target_lengths"] - 1 + ) * self.args.n_frames_per_step # strip <eos> + for b, (f, l) in enumerate(zip(sample["target"], tgt_lens)): + hypos[b][0]["targ_waveform"] = self.vocoder( + {"code": f[:l] - 4}, # remove <bos>, <pad>, <eos>, <unk> + dur_prediction=self.eval_gen_args.get("dur_prediction", False), + ) + if len(hypos[b][0]["tokens"]) > 0: + hypos[b][0]["waveform"] = self.vocoder( + {"code": hypos[b][0]["tokens"] - 4}, + dur_prediction=self.eval_gen_args.get("dur_prediction", False), + ) + else: + hypos[b][0]["waveform"] = torch.flip( + hypos[b][0]["targ_waveform"], dims=[0] + ) + else: + hypos = [ + [hypo] for hypo in generator.generate(model, sample, has_targ=True) + ] + + losses = { + "mcd_loss": 0.0, + "targ_frames": 0.0, + "pred_frames": 0.0, + "path_frames": 0.0, + "nins": 0.0, + "ndel": 0.0, + } + rets = batch_mel_cepstral_distortion( + [hypo[0]["targ_waveform"] for hypo in hypos], + [hypo[0]["waveform"] for hypo in hypos], + self.data_cfg.output_sample_rate, + normalize_type=None, + ) + for d, extra in rets: + pathmap = extra[-1] + losses["mcd_loss"] += d.item() + losses["targ_frames"] += pathmap.size(0) + losses["pred_frames"] += pathmap.size(1) + losses["path_frames"] += pathmap.sum().item() + losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item() + losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item() + losses["norm_frames"] = losses[ + f"{getattr(self.args, 'mcd_normalize_type', 'targ')}_frames" + ] + + return hypos, losses + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + if self._infer_tgt_lang_id is not None: + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + bos_token=self._infer_tgt_lang_id, + ) + else: + return super().inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) diff --git a/fairseq/fairseq/tasks/speech_to_text.py b/fairseq/fairseq/tasks/speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..884082112a6763e0edb126e7ee21c83e40127823 --- /dev/null +++ b/fairseq/fairseq/tasks/speech_to_text.py @@ -0,0 +1,350 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
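For reference on the MCD bookkeeping in SpeechToSpeechTask.valid_step_with_inference above: the nins/ndel counters are read off the binary DTW path map returned by batch_mel_cepstral_distortion. A small sketch with a made-up 3x4 path map (rows are target frames, columns are predicted frames); a target frame crossed by more than one predicted frame counts as an insertion, and a predicted frame crossed by more than one target frame counts as a deletion.

import torch

# Hypothetical DTW path map (example data only): rows = target frames, columns = predicted frames.
pathmap = torch.tensor(
    [
        [1, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]
)
nins = (pathmap.sum(dim=1) - 1).sum().item()  # target frames aligned to >1 predicted frame
ndel = (pathmap.sum(dim=0) - 1).sum().item()  # predicted frames aligned to >1 target frame
print(nins, ndel)  # 1 0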
+ +import logging +from argparse import Namespace +from pathlib import Path +from typing import List + +from fairseq.data import Dictionary, encoders +from fairseq.data.audio.audio_utils import get_features_or_waveform +from fairseq.data.audio.data_cfg import MultitaskConfig +from fairseq.data.audio.speech_to_text_dataset import ( + S2TDataConfig, + SpeechToTextDataset, + SpeechToTextDatasetCreator, + TextTargetMultitaskData, +) +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("speech_to_text") +class SpeechToTextTask(LegacyFairseqTask): + @classmethod + def add_args(cls, parser): + parser.add_argument("data", help="manifest root path") + parser.add_argument( + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--multitask-config-yaml", + type=str, + default=None, + help="Configuration YAML filename for the multitasks (under manifest root)", + ) + parser.add_argument( + "--max-source-positions", + default=6000, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + + def __init__(self, args, tgt_dict): + super().__init__(args) + self.tgt_dict = tgt_dict + self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) + self.speaker_to_id = self._get_speaker_to_id() + if ( + self.data_cfg.prepend_tgt_lang_tag + and self.data_cfg.prepend_bos_and_append_tgt_lang_tag + ): + raise ValueError( + "Please set only one of the two options to avoid adding target token multiple times" + ) + + self.multitask_tasks = {} + self.tgt_dict_mt = None + self.eos_token_mt = None + if getattr(args, "multitask_config_yaml", None) is not None: + multitask_cfg = MultitaskConfig( + Path(args.data) / args.multitask_config_yaml + ) + first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index + for i, (task_name, task_config) in enumerate( + multitask_cfg.get_all_tasks().items() + ): + task_obj = DummyMultiTask( + task_config, + task_config.tgt_dict, + first_pass=i == first_pass_task_idx, + ) + self.multitask_tasks[task_name] = task_obj + if task_obj.is_first_pass_decoder: + self.tgt_dict_mt = task_obj.target_dictionary + if task_config.prepend_bos_and_append_tgt_lang_tag: + self.eos_token_mt = task_config.eos_token + assert not isinstance(self.eos_token_mt, List) + + if not self.eos_token_mt: + raise Warning( + "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator" + ) + + def _get_speaker_to_id(self): + speaker_to_id = None + speaker_set_filename = self.data_cfg.config.get("speaker_set_filename") + if speaker_set_filename is not None: + speaker_set_path = Path(self.args.data) / speaker_set_filename + with open(speaker_set_path) as f: + speaker_to_id = {r.strip(): i for i, r in enumerate(f)} + return speaker_to_id + + @classmethod + def setup_task(cls, args, **kwargs): + data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) + dict_path = Path(args.data) / data_cfg.vocab_filename + if not dict_path.is_file(): + raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}") + tgt_dict = Dictionary.load(dict_path.as_posix()) + logger.info( + f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}" + ) + + if getattr(args, "train_subset", None) is not None: + if not all(s.startswith("train") for s in 
args.train_subset.split(",")): + raise ValueError('Train splits should be named like "train*".') + return cls(args, tgt_dict) + + def build_criterion(self, args): + from fairseq import criterions + + if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1: + raise ValueError( + 'Please set "--ignore-prefix-size 1" since ' + "target language ID token is prepended as BOS." + ) + return criterions.build_criterion(args, self) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = SpeechToTextDatasetCreator.from_tsv( + root=self.args.data, + cfg=self.data_cfg, + splits=split, + tgt_dict=self.tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + is_train_split=is_train_split, + epoch=epoch, + seed=self.args.seed, + speaker_to_id=self.speaker_to_id, + multitask=self.multitask_tasks, + ) + + @property + def target_dictionary(self): + return self.tgt_dict + + @property + def target_dictionary_mt(self): + return self.tgt_dict_mt + + @property + def source_dictionary(self): + return None + + def max_positions(self): + return self.args.max_source_positions, self.args.max_target_positions + + def build_model(self, args, from_checkpoint=False): + args.input_feat_per_channel = self.data_cfg.input_feat_per_channel + args.input_channels = self.data_cfg.input_channels + args.speaker_to_id = self.speaker_to_id + return super(SpeechToTextTask, self).build_model(args, from_checkpoint) + + def build_generator_dual_decoder( + self, + models, + args, + extra_gen_cls_kwargs, + ): + from examples.speech_to_speech.unity.sequence_generator_multi_decoder import ( + MultiDecoderSequenceGenerator, + ) + + lang_token_ids_aux = { + i + for s, i in self.tgt_dict_mt.indices.items() + if TextTargetMultitaskData.is_lang_tag(s) + } + + extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux) + + eos_id_mt = ( + self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None + ) + assert eos_id_mt != self.tgt_dict_mt.unk() + extra_gen_cls_kwargs["eos_mt"] = eos_id_mt + + return MultiDecoderSequenceGenerator( + models, + self.target_dictionary, + self.target_dictionary_mt, + beam_size=max(1, getattr(args, "beam", 1)), + beam_size_mt=max(1, getattr(args, "beam_mt", 1)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + max_len_a_mt=getattr(args, "max_len_a_mt", 0), + max_len_b_mt=getattr(args, "max_len_b_mt", 0), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + len_penalty_mt=getattr(args, "lenpen_mt", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + **extra_gen_cls_kwargs, + ) + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + ): + if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1: + raise ValueError( + 'Please set "--prefix-size 1" since ' + "target language ID token is prepended as BOS." 
+ ) + lang_token_ids = { + i + for s, i in self.tgt_dict.indices.items() + if SpeechToTextDataset.is_lang_tag(s) + } + + if extra_gen_cls_kwargs is None: + extra_gen_cls_kwargs = {} + extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids + + eos_token = ( + args.eos_token + if "eos_token" in args and args.eos_token is not None + else self.data_cfg.config.get("eos_token", None) + ) + + if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token: + raise Warning( + "Please provide --eos_token to replace eos in sequence generator" + ) + + eos_id = self.tgt_dict.index(eos_token) if eos_token else None + extra_gen_cls_kwargs["eos"] = eos_id + + has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None + + if has_dual_decoder: + return self.build_generator_dual_decoder( + models, + args, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + else: + return super().build_generator( + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + for task_name, task_obj in self.multitask_tasks.items(): + criterion.set_multitask_loss_weight( + task_name, task_obj.args.get_loss_weight(update_num) + ) + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].train() + + loss, sample_size, logging_output = super().train_step( + sample, model, criterion, optimizer, update_num, ignore_grad + ) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + for task_name, task_obj in self.multitask_tasks.items(): + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].eval() + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + return loss, sample_size, logging_output + + def build_tokenizer(self, args): + logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}") + return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer)) + + def build_bpe(self, args): + logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}") + return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer)) + + def get_interactive_tokens_and_lengths(self, lines, encode_fn): + n_frames = [get_features_or_waveform(p).shape[0] for p in lines] + return lines, n_frames + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + return SpeechToTextDataset( + "interactive", False, self.data_cfg, src_tokens, src_lengths + ) + + +class DummyMultiTask(LegacyFairseqTask): + def __init__(self, args, tgt_dict, first_pass=False): + super().__init__(args) + self.tgt_dict = tgt_dict + self.first_pass = first_pass + + @property + def target_dictionary(self): + return self.tgt_dict + + @property + def is_first_pass_decoder(self): + return self.first_pass + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + if self.args.decoder_type == "ctc": + model = models[0] # only support single model + encoder_out = model(**sample) + if hasattr(model, "get_logits"): + emissions = model.get_logits( + encoder_out + ) # no need to normalize emissions + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) + return generator.decode( + emissions.transpose(0, 1).float().cpu().contiguous() + ) + else: + raise NotImplementedError("only ctc decoder is supported at the moment") + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None + ): + if 
self.args.decoder_type == "ctc": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, self.tgt_dict) + else: + raise NotImplementedError("only ctc decoder is supported at the moment") diff --git a/fairseq/fairseq/tasks/speech_ulm_task.py b/fairseq/fairseq/tasks/speech_ulm_task.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d3019d5049f4f0ba1673d8b3e77c99db951887 --- /dev/null +++ b/fairseq/fairseq/tasks/speech_ulm_task.py @@ -0,0 +1,224 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import sys +import torch +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +from fairseq.data import Dictionary +from fairseq.data.codedataset import ExpressiveCodeDataConfig, CodeDataset +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING, DictConfig + + +logger = logging.getLogger(__name__) + + +class UnitDictionary(Dictionary): + """ + A fixed-sized Dictionary that operates on integer-valued tokens + with a trivial (identity) token <-> id mapping. + Special symbols (bos, eos, ...) have ids above n_units. + """ + + def __init__( + self, + *, # begin keyword-only arguments + n_units, + bos="<s>", + pad="<pad>", + eos="</s>", + unk="<unk>", + extra_special_symbols=None, + clip=False, + ): + self.n_units = n_units + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + self.clip = clip + + self.symbols = [] + self.count = [] + self.indices = {} + for i in range(n_units): + self.add_symbol(str(i)) + + self.bos_index = self.add_symbol(bos) + self.pad_index = self.add_symbol(pad) + self.eos_index = self.add_symbol(eos) + self.unk_index = self.add_symbol(unk) + + if extra_special_symbols: + for s in extra_special_symbols: + self.add_symbol(s) + self.nspecial = len(self.symbols) + + def encode_line(self, line, append_eos=True, prepend_bos=False) -> torch.IntTensor: + words = [int(x) for x in line.split()] + if self.clip: + words = [min(self.n_units - 1, word) for word in words] + if prepend_bos: + words = [self.bos_index] + words + if append_eos: + words.append(self.eos_index) + ids = torch.IntTensor(words) + return ids + + +@dataclass +class SpeechUnitModelingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "Path to data config.json"}) + max_token_duration: int = field( + default=20, metadata={"help": "all token durations are capped to this value"} + ) + tokens_per_sample: int = field( + default=1024, metadata={"help": "tokens in a sample"} + ) + max_target_positions: int = field( + default=1024, metadata={"help": "max target positions"} + ) + + # duration modeling + ignore_duration_input: bool = field( + default=False, metadata={"help": "whether token durations should be zeroed out"} + ) + discrete_duration: bool = field( + default=False, metadata={"help": "treat duration as discrete variable"} + ) + # F0 modeling + ignore_f0_input: bool = field( + default=False, metadata={"help": "whether F0 should be zeroed out"} + ) + discrete_f0: bool = field( + default=False, metadata={"help": "load quantized f0.
get bin from config"} + ) + log_f0: bool = field( + default=False, metadata={"help": "whether f0 should be modeled in log space"} + ) + normalize_f0_mean: bool = field( + default=False, metadata={"help": "whether normalize f0 by speaker mean"} + ) + normalize_f0_std: bool = field( + default=False, metadata={"help": "whether normalize f0 by speaker stddev"} + ) + interpolate_f0: bool = field( + default=False, + metadata={"help": "whether interpolate f0 for non-voiced segments"}, + ) + + # input/output streams + stream_shifts: str = field( + default="0,0", + metadata={ + "help": ( + "comma-separated integer list denoting right-shift for " + "duration and pitch streams" + ) + }, + ) + + +@register_task("speech_unit_modeling", dataclass=SpeechUnitModelingConfig) +class SpeechUnitLanguageModelingTask(FairseqTask): + def __init__(self, cfg: SpeechUnitModelingConfig) -> None: + super().__init__(cfg) + assert not self.cfg.normalize_f0_std or self.cfg.normalize_f0_mean + + self.data_config = ExpressiveCodeDataConfig(cfg.data) + self._source_dictionary = self._target_dictionary = UnitDictionary( + n_units=self.data_config.n_units + ) + self._source_duration_dictionary = self._target_duration_dictionary = ( + UnitDictionary(n_units=self.cfg.max_token_duration + 1, clip=True) + if self.cfg.discrete_duration + else None + ) + self._source_f0_dictionary = self._target_f0_dictionary = ( + UnitDictionary(n_units=self.data_config.f0_vq_n_units) + if self.cfg.discrete_f0 + else None + ) + + self._channel_names = ["token", "duration", "f0"] + self._channel_sizes = [ + len(self.target_dictionary), + len(self.target_duration_dictionary) if self.cfg.discrete_duration else 1, + len(self.target_f0_dictionary) if self.cfg.discrete_f0 else 1, + ] + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return self._source_dictionary + + @property + def source_duration_dictionary(self) -> Optional[Dictionary]: + return self._source_duration_dictionary + + @property + def source_f0_dictionary(self) -> Optional[Dictionary]: + return self._source_f0_dictionary + + @property + def channel_names(self) -> List[str]: + return self._channel_names + + @property + def channel_sizes(self) -> List[int]: + return self._channel_sizes + + @property + def dictionary(self) -> Optional[Dictionary]: + return self._source_dictionary + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self._target_dictionary + + @property + def target_duration_dictionary(self) -> Optional[Dictionary]: + return self._target_duration_dictionary + + @property + def target_f0_dictionary(self) -> Optional[Dictionary]: + return self._target_f0_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return [self._dictionaries[l] for l in self.cfg.labels] + + @classmethod + def setup_task( + cls, cfg: SpeechUnitModelingConfig, **kwargs + ) -> "SpeechUnitLanguageModelingTask": + return cls(cfg) + + def load_dataset(self, split: str, **kwargs) -> None: + self.datasets[split] = CodeDataset( + manifest=self.data_config.manifests[split], + dictionary=self.source_dictionary, + dur_dictionary=self.source_duration_dictionary, + f0_dictionary=self.source_f0_dictionary, + config=self.data_config, + discrete_dur=self.cfg.discrete_duration, + discrete_f0=self.cfg.discrete_f0, + log_f0=self.cfg.log_f0, + normalize_f0_mean=self.cfg.normalize_f0_mean, + normalize_f0_std=self.cfg.normalize_f0_std, + interpolate_f0=self.cfg.interpolate_f0, + shifts=self.cfg.stream_shifts, + ) + + def max_positions(self) -> Tuple[int, 
int]: + return (sys.maxsize, sys.maxsize) + + def build_criterion(self, cfg: DictConfig): + import fairseq.criterions + + return fairseq.criterions.build_criterion(cfg, self) diff --git a/fairseq/fairseq/tasks/text_to_speech.py b/fairseq/fairseq/tasks/text_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..82e7e6643af719a9bc0b4d5ef446365b8ef7e8fb --- /dev/null +++ b/fairseq/fairseq/tasks/text_to_speech.py @@ -0,0 +1,501 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import os.path as op + +import torch +import torch.nn.functional as F +import numpy as np + +from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask +from fairseq.speech_generator import ( + AutoRegressiveSpeechGenerator, + NonAutoregressiveSpeechGenerator, + TeacherForcingAutoRegressiveSpeechGenerator, +) + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +try: + from tensorboardX import SummaryWriter +except ImportError: + logger.info("Please install tensorboardX: pip install tensorboardX") + SummaryWriter = None + + +@register_task("text_to_speech") +class TextToSpeechTask(SpeechToTextTask): + @staticmethod + def add_args(parser): + parser.add_argument("data", help="manifest root path") + parser.add_argument( + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1200, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + parser.add_argument("--n-frames-per-step", type=int, default=1) + parser.add_argument("--eos-prob-threshold", type=float, default=0.5) + parser.add_argument("--eval-inference", action="store_true") + parser.add_argument("--eval-tb-nsample", type=int, default=8) + parser.add_argument("--vocoder", type=str, default="griffin_lim") + parser.add_argument("--spec-bwd-max-iter", type=int, default=8) + + def __init__(self, args, src_dict): + super().__init__(args, src_dict) + self.src_dict = src_dict + self.sr = self.data_cfg.config.get("features").get("sample_rate") + + self.tensorboard_writer = None + self.tensorboard_dir = "" + if args.tensorboard_logdir and SummaryWriter is not None: + self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra") + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = TextToSpeechDatasetCreator.from_tsv( + self.args.data, + self.data_cfg, + split, + self.src_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split=is_train_split, + epoch=epoch, + seed=self.args.seed, + n_frames_per_step=self.args.n_frames_per_step, + speaker_to_id=self.speaker_to_id, + ) + + @property + def target_dictionary(self): + return None + + @property + def source_dictionary(self): + return self.src_dict + + def get_speaker_embeddings_path(self): + 
speaker_emb_path = None + if self.data_cfg.config.get("speaker_emb_filename") is not None: + speaker_emb_path = op.join( + self.args.data, self.data_cfg.config.get("speaker_emb_filename") + ) + return speaker_emb_path + + @classmethod + def get_speaker_embeddings(cls, args): + embed_speaker = None + if args.speaker_to_id is not None: + if args.speaker_emb_path is None: + embed_speaker = torch.nn.Embedding( + len(args.speaker_to_id), args.speaker_embed_dim + ) + else: + speaker_emb_mat = np.load(args.speaker_emb_path) + assert speaker_emb_mat.shape[1] == args.speaker_embed_dim + embed_speaker = torch.nn.Embedding.from_pretrained( + torch.from_numpy(speaker_emb_mat), + freeze=True, + ) + logger.info( + f"load speaker embeddings from {args.speaker_emb_path}. " + f"train embedding? {embed_speaker.weight.requires_grad}\n" + f"embeddings:\n{speaker_emb_mat}" + ) + return embed_speaker + + def build_model(self, cfg, from_checkpoint=False): + cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None) + cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None) + cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None) + cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None) + cfg.speaker_emb_path = self.get_speaker_embeddings_path() + model = super().build_model(cfg, from_checkpoint) + self.generator = None + if getattr(cfg, "eval_inference", False): + self.generator = self.build_generator([model], cfg) + return model + + def build_generator(self, models, cfg, vocoder=None, **unused): + if vocoder is None: + vocoder = self.build_default_vocoder() + model = models[0] + if getattr(model, "NON_AUTOREGRESSIVE", False): + return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg) + else: + generator = AutoRegressiveSpeechGenerator + if getattr(cfg, "teacher_forcing", False): + generator = TeacherForcingAutoRegressiveSpeechGenerator + logger.info("Teacher forcing mode for generation") + return generator( + model, + vocoder, + self.data_cfg, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold, + ) + + def build_default_vocoder(self): + from fairseq.models.text_to_speech.vocoder import get_vocoder + + vocoder = get_vocoder(self.args, self.data_cfg) + if torch.cuda.is_available() and not self.args.cpu: + vocoder = vocoder.cuda() + else: + vocoder = vocoder.cpu() + return vocoder + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + if getattr(self.args, "eval_inference", False): + hypos, inference_losses = self.valid_step_with_inference( + sample, model, self.generator + ) + for k, v in inference_losses.items(): + assert k not in logging_output + logging_output[k] = v + + picked_id = 0 + if self.tensorboard_dir and (sample["id"] == picked_id).any(): + self.log_tensorboard( + sample, + hypos[: self.args.eval_tb_nsample], + model._num_updates, + is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False), + ) + return loss, sample_size, logging_output + + def valid_step_with_inference(self, sample, model, generator): + hypos = generator.generate(model, sample, has_targ=True) + + losses = { + "mcd_loss": 0.0, + "targ_frames": 0.0, + "pred_frames": 0.0, + "nins": 0.0, + "ndel": 0.0, + } + rets = batch_mel_cepstral_distortion( + [hypo["targ_waveform"] for hypo in hypos], + [hypo["waveform"] for hypo in hypos], + self.sr, + normalize_type=None, + ) + for d, extra in rets: + pathmap = extra[-1] + losses["mcd_loss"] += 
d.item() + losses["targ_frames"] += pathmap.size(0) + losses["pred_frames"] += pathmap.size(1) + losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item() + losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item() + + return hypos, losses + + def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False): + if self.tensorboard_writer is None: + self.tensorboard_writer = SummaryWriter(self.tensorboard_dir) + tb_writer = self.tensorboard_writer + for b in range(len(hypos)): + idx = sample["id"][b] + text = sample["src_texts"][b] + targ = hypos[b]["targ_feature"] + pred = hypos[b]["feature"] + attn = hypos[b]["attn"] + + if is_na_model: + data = plot_tts_output( + [targ.transpose(0, 1), pred.transpose(0, 1)], + [f"target (idx={idx})", "output"], + attn, + "alignment", + ret_np=True, + suptitle=text, + ) + else: + eos_prob = hypos[b]["eos_prob"] + data = plot_tts_output( + [targ.transpose(0, 1), pred.transpose(0, 1), attn], + [f"target (idx={idx})", "output", "alignment"], + eos_prob, + "eos prob", + ret_np=True, + suptitle=text, + ) + + tb_writer.add_image( + f"inference_sample_{b}", data, num_updates, dataformats="HWC" + ) + + if hypos[b]["waveform"] is not None: + targ_wave = hypos[b]["targ_waveform"].detach().cpu().float() + pred_wave = hypos[b]["waveform"].detach().cpu().float() + tb_writer.add_audio( + f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr + ) + tb_writer.add_audio( + f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr + ) + + +def save_figure_to_numpy(fig): + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +DEFAULT_V_MIN = np.log(1e-5) + + +def plot_tts_output( + data_2d, + title_2d, + data_1d, + title_1d, + figsize=(24, 4), + v_min=DEFAULT_V_MIN, + v_max=3, + ret_np=False, + suptitle="", +): + try: + import matplotlib.pyplot as plt + from mpl_toolkits.axes_grid1 import make_axes_locatable + except ImportError: + raise ImportError("Please install Matplotlib: pip install matplotlib") + + data_2d = [ + x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x + for x in data_2d + ] + fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize) + if suptitle: + fig.suptitle(suptitle[:400]) # capped at 400 chars + axes = [axes] if len(data_2d) == 0 else axes + for ax, x, name in zip(axes, data_2d, title_2d): + ax.set_title(name) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + im = ax.imshow( + x, + origin="lower", + aspect="auto", + vmin=max(x.min(), v_min), + vmax=min(x.max(), v_max), + ) + fig.colorbar(im, cax=cax, orientation="vertical") + + if isinstance(data_1d, torch.Tensor): + data_1d = data_1d.detach().cpu().numpy() + axes[-1].plot(data_1d) + axes[-1].set_title(title_1d) + plt.tight_layout() + + if ret_np: + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close(fig) + return data + + +def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None): + """ + for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs + + offset=2 (1, 1), + offset=3 (2, 1), (1, 2) + offset=4 (2, 2), (1, 3) + offset=5 (2, 3) + + constraints: + i + j = offset + min_j <= j < max_j + min_i <= offset - j < max_i + """ + if max_i is None: + max_i = offset + 1 + if max_j is None: + max_j = offset + 1 + min_j = max(min_j, offset - max_i + 1, 0) + max_j = min(max_j, offset - min_i + 1, offset + 1) + j = torch.arange(min_j, max_j) + i = offset - j + return 
torch.stack([i, j]) + + +def batch_dynamic_time_warping(distance, shapes=None): + """full batched DTW without any constraints + + distance: (batchsize, max_M, max_N) matrix + shapes: (batchsize,) vector specifying (M, N) for each entry + """ + # ptr: 0=left, 1=up-left, 2=up + ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)} + + bsz, m, n = distance.size() + cumdist = torch.zeros_like(distance) + backptr = torch.zeros_like(distance).type(torch.int32) - 1 + + # initialize + cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1) + cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1) + backptr[:, 0, :] = 0 + backptr[:, :, 0] = 2 + + # DP with optimized anti-diagonal parallelization, O(M+N) steps + for offset in range(2, m + n - 1): + ind = antidiag_indices(offset, 1, m, 1, n) + c = torch.stack( + [ + cumdist[:, ind[0], ind[1] - 1], + cumdist[:, ind[0] - 1, ind[1] - 1], + cumdist[:, ind[0] - 1, ind[1]], + ], + dim=2, + ) + v, b = c.min(axis=-1) + backptr[:, ind[0], ind[1]] = b.int() + cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]] + + # backtrace + pathmap = torch.zeros_like(backptr) + for b in range(bsz): + i = m - 1 if shapes is None else (shapes[b][0] - 1).item() + j = n - 1 if shapes is None else (shapes[b][1] - 1).item() + dtwpath = [(i, j)] + while (i != 0 or j != 0) and len(dtwpath) < 10000: + assert i >= 0 and j >= 0 + di, dj = ptr2dij[backptr[b, i, j].item()] + i, j = i + di, j + dj + dtwpath.append((i, j)) + dtwpath = dtwpath[::-1] + indices = torch.from_numpy(np.array(dtwpath)) + pathmap[b, indices[:, 0], indices[:, 1]] = 1 + + return cumdist, backptr, pathmap + + +def compute_l2_dist(x1, x2): + """compute an (m, n) L2 distance matrix from (m, d) and (n, d) matrices""" + return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2) + + +def compute_rms_dist(x1, x2): + l2_dist = compute_l2_dist(x1, x2) + return (l2_dist / x1.size(1)).pow(0.5) + + +def get_divisor(pathmap, normalize_type): + if normalize_type is None: + return 1 + elif normalize_type == "len1": + return pathmap.size(0) + elif normalize_type == "len2": + return pathmap.size(1) + elif normalize_type == "path": + return pathmap.sum().item() + else: + raise ValueError(f"normalize_type {normalize_type} not supported") + + +def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type): + d, s, x1, x2 = [], [], [], [] + for cur_y1, cur_y2 in zip(y1, y2): + assert cur_y1.ndim == 1 and cur_y2.ndim == 1 + cur_x1 = feat_fn(cur_y1) + cur_x2 = feat_fn(cur_y2) + x1.append(cur_x1) + x2.append(cur_x2) + + cur_d = dist_fn(cur_x1, cur_x2) + d.append(cur_d) + s.append(d[-1].size()) + max_m = max(ss[0] for ss in s) + max_n = max(ss[1] for ss in s) + d = torch.stack( + [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d] + ) + s = torch.LongTensor(s).to(d.device) + cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s) + + rets = [] + itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps) + for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr: + cumdist = cumdist[:m, :n] + backptr = backptr[:m, :n] + pathmap = pathmap[:m, :n] + divisor = get_divisor(pathmap, normalize_type) + + distortion = cumdist[-1, -1] / divisor + ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap) + rets.append(ret) + return rets + + +def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None): + """ + https://arxiv.org/pdf/2011.03568.pdf + + The root mean squared error computed on 13-dimensional MFCC using DTW for + alignment. 
MFCC features are computed from an 80-channel log-mel + spectrogram using a 50ms Hann window and hop of 12.5ms. + + y1: list of waveforms + y2: list of waveforms + sr: sampling rate + """ + + try: + import torchaudio + except ImportError: + raise ImportError("Please install torchaudio: pip install torchaudio") + + if mfcc_fn is None or mfcc_fn.sample_rate != sr: + melkwargs = { + "n_fft": int(0.05 * sr), + "win_length": int(0.05 * sr), + "hop_length": int(0.0125 * sr), + "f_min": 20, + "n_mels": 80, + "window_fn": torch.hann_window, + } + mfcc_fn = torchaudio.transforms.MFCC( + sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs + ).to(y1[0].device) + return batch_compute_distortion( + y1, + y2, + sr, + lambda y: mfcc_fn(y).transpose(-1, -2), + compute_rms_dist, + normalize_type, + ) diff --git a/fairseq/fairseq/tasks/translation.py b/fairseq/fairseq/tasks/translation.py new file mode 100644 index 0000000000000000000000000000000000000000..6897ebe116a28c03b1632deff04d6d570398d2f0 --- /dev/null +++ b/fairseq/fairseq/tasks/translation.py @@ -0,0 +1,498 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +import itertools +import json +import logging +import os +from typing import Optional +from argparse import Namespace +from omegaconf import II + +import numpy as np +from fairseq import utils +from fairseq.logging import metrics +from fairseq.data import ( + AppendTokenDataset, + ConcatDataset, + LanguagePairDataset, + PrependTokenDataset, + StripTokenDataset, + TruncateDataset, + data_utils, + encoders, + indexed_dataset, +) +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + + +EVAL_BLEU_ORDER = 4 + + +logger = logging.getLogger(__name__) + + +def load_langpair_dataset( + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos=False, + load_alignments=False, + truncate_source=False, + append_source_id=False, + num_buckets=0, + shuffle=True, + pad_to_multiple=1, + prepend_bos_src=None, +): + def split_exists(split, src, tgt, lang, data_path): + filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) + return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + + src_datasets = [] + tgt_datasets = [] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + + # infer langcode + if split_exists(split_k, src, tgt, src, data_path): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) + elif split_exists(split_k, tgt, src, src, data_path): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) + else: + if k > 0: + break + else: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + src_dataset = data_utils.load_indexed_dataset( + prefix + src, src_dict, dataset_impl + ) + if truncate_source: + src_dataset = AppendTokenDataset( + TruncateDataset( + StripTokenDataset(src_dataset, src_dict.eos()), + max_source_positions - 1, + ), + src_dict.eos(), + ) + src_datasets.append(src_dataset) + + tgt_dataset = data_utils.load_indexed_dataset( + prefix + tgt, tgt_dict, dataset_impl + ) + if tgt_dataset is not None: + 
tgt_datasets.append(tgt_dataset) + + logger.info( + "{} {} {}-{} {} examples".format( + data_path, split_k, src, tgt, len(src_datasets[-1]) + ) + ) + + if not combine: + break + + assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 + + if len(src_datasets) == 1: + src_dataset = src_datasets[0] + tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None + else: + sample_ratios = [1] * len(src_datasets) + sample_ratios[0] = upsample_primary + src_dataset = ConcatDataset(src_datasets, sample_ratios) + if len(tgt_datasets) > 0: + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + else: + tgt_dataset = None + + if prepend_bos: + assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") + src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) + if tgt_dataset is not None: + tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + elif prepend_bos_src is not None: + logger.info(f"prepending src bos: {prepend_bos_src}") + src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src) + + eos = None + if append_source_id: + src_dataset = AppendTokenDataset( + src_dataset, src_dict.index("[{}]".format(src)) + ) + if tgt_dataset is not None: + tgt_dataset = AppendTokenDataset( + tgt_dataset, tgt_dict.index("[{}]".format(tgt)) + ) + eos = tgt_dict.index("[{}]".format(tgt)) + + align_dataset = None + if load_alignments: + align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt)) + if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): + align_dataset = data_utils.load_indexed_dataset( + align_path, None, dataset_impl + ) + + tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None + return LanguagePairDataset( + src_dataset, + src_dataset.sizes, + src_dict, + tgt_dataset, + tgt_dataset_sizes, + tgt_dict, + left_pad_source=left_pad_source, + left_pad_target=left_pad_target, + align_dataset=align_dataset, + eos=eos, + num_buckets=num_buckets, + shuffle=shuffle, + pad_to_multiple=pad_to_multiple, + ) + + +@dataclass +class TranslationConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, + metadata={ + "help": "colon separated path to data directories list, will be iterated upon during epochs " + "in round-robin manner; however, valid and test data are always in the first directory " + "to avoid the need for repeating them in all directories" + }, + ) + source_lang: Optional[str] = field( + default=None, + metadata={ + "help": "source language", + "argparse_alias": "-s", + }, + ) + target_lang: Optional[str] = field( + default=None, + metadata={ + "help": "target language", + "argparse_alias": "-t", + }, + ) + load_alignments: bool = field( + default=False, metadata={"help": "load the binarized alignments"} + ) + left_pad_source: bool = field( + default=True, metadata={"help": "pad the source on the left"} + ) + left_pad_target: bool = field( + default=False, metadata={"help": "pad the target on the left"} + ) + max_source_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + upsample_primary: int = field( + default=-1, metadata={"help": "the amount of upsample primary dataset"} + ) + truncate_source: bool = field( + default=False, metadata={"help": "truncate source to max-source-positions"} + ) + num_batch_buckets: int = field( + default=0, + metadata={ + "help": "if >0, then bucket source and 
target lengths into " + "N buckets and pad accordingly; this is useful on TPUs to minimize the number of compilations" + }, + ) + train_subset: str = II("dataset.train_subset") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + required_seq_len_multiple: int = II("dataset.required_seq_len_multiple") + + # options for reporting BLEU during validation + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_args: Optional[str] = field( + default="{}", + metadata={ + "help": 'generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string' + }, + ) + eval_bleu_detok: str = field( + default="space", + metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); required if using --eval-bleu; " + "use 'space' to disable detokenization; see fairseq.data.encoders for other options" + }, + ) + eval_bleu_detok_args: Optional[str] = field( + default="{}", + metadata={"help": "args for building the tokenizer, if needed, as JSON string"}, + ) + eval_tokenized_bleu: bool = field( + default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, + metadata={ + "help": "remove BPE before computing BLEU", + "argparse_const": "@@ ", + }, + ) + eval_bleu_print_samples: bool = field( + default=False, metadata={"help": "print sample generations during validation"} + ) + + +@register_task("translation", dataclass=TranslationConfig) +class TranslationTask(FairseqTask): + """ + Translate from one (source) language to another (target) language. + + Args: + src_dict (~fairseq.data.Dictionary): dictionary for the source language + tgt_dict (~fairseq.data.Dictionary): dictionary for the target language + + .. note:: + + The translation task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate` and :mod:`fairseq-interactive`. + """ + + cfg: TranslationConfig + + def __init__(self, cfg: TranslationConfig, src_dict, tgt_dict): + super().__init__(cfg) + self.src_dict = src_dict + self.tgt_dict = tgt_dict + + @classmethod + def setup_task(cls, cfg: TranslationConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + # find language pair automatically + if cfg.source_lang is None or cfg.target_lang is None: + cfg.source_lang, cfg.target_lang = data_utils.infer_language_pair(paths[0]) + if cfg.source_lang is None or cfg.target_lang is None: + raise Exception( + "Could not infer language pair, please provide it explicitly" + ) + + # load dictionaries + src_dict = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang)) + ) + tgt_dict = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang)) + ) + assert src_dict.pad() == tgt_dict.pad() + assert src_dict.eos() == tgt_dict.eos() + assert src_dict.unk() == tgt_dict.unk() + logger.info("[{}] dictionary: {} types".format(cfg.source_lang, len(src_dict))) + logger.info("[{}] dictionary: {} types".format(cfg.target_lang, len(tgt_dict))) + + return cls(cfg, src_dict, tgt_dict) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. 
+ + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + if split != self.cfg.train_subset: + # if not training data set, use the first shard for valid and test + paths = paths[:1] + data_path = paths[(epoch - 1) % len(paths)] + + # infer langcode + src, tgt = self.cfg.source_lang, self.cfg.target_lang + + self.datasets[split] = load_langpair_dataset( + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.cfg.dataset_impl, + upsample_primary=self.cfg.upsample_primary, + left_pad_source=self.cfg.left_pad_source, + left_pad_target=self.cfg.left_pad_target, + max_source_positions=self.cfg.max_source_positions, + max_target_positions=self.cfg.max_target_positions, + load_alignments=self.cfg.load_alignments, + truncate_source=self.cfg.truncate_source, + num_buckets=self.cfg.num_batch_buckets, + shuffle=(split != "test"), + pad_to_multiple=self.cfg.required_seq_len_multiple, + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + return LanguagePairDataset( + src_tokens, + src_lengths, + self.source_dictionary, + tgt_dict=self.target_dictionary, + constraints=constraints, + ) + + def build_model(self, cfg, from_checkpoint=False): + model = super().build_model(cfg, from_checkpoint) + if self.cfg.eval_bleu: + detok_args = json.loads(self.cfg.eval_bleu_detok_args) + self.tokenizer = encoders.build_tokenizer( + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + ) + + gen_args = json.loads(self.cfg.eval_bleu_args) + self.sequence_generator = self.build_generator( + [model], Namespace(**gen_args) + ) + return model + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + if self.cfg.eval_bleu: + bleu = self._inference_with_bleu(self.sequence_generator, sample, model) + logging_output["_bleu_sys_len"] = bleu.sys_len + logging_output["_bleu_ref_len"] = bleu.ref_len + # we split counts into separate entries so that they can be + # summed efficiently across workers using fast-stat-sync + assert len(bleu.counts) == EVAL_BLEU_ORDER + for i in range(EVAL_BLEU_ORDER): + logging_output["_bleu_counts_" + str(i)] = bleu.counts[i] + logging_output["_bleu_totals_" + str(i)] = bleu.totals[i] + return loss, sample_size, logging_output + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + if self.cfg.eval_bleu: + + def sum_logs(key): + import torch + + result = sum(log.get(key, 0) for log in logging_outputs) + if torch.is_tensor(result): + result = result.cpu() + return result + + counts, totals = [], [] + for i in range(EVAL_BLEU_ORDER): + counts.append(sum_logs("_bleu_counts_" + str(i))) + totals.append(sum_logs("_bleu_totals_" + str(i))) + + if max(totals) > 0: + # log counts as numpy arrays -- log_scalar will sum them correctly + metrics.log_scalar("_bleu_counts", np.array(counts)) + metrics.log_scalar("_bleu_totals", np.array(totals)) + metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len")) + metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len")) + + def compute_bleu(meters): + import inspect + + try: + from sacrebleu.metrics import BLEU + + comp_bleu = BLEU.compute_bleu + except ImportError: + # compatibility API for sacrebleu 1.x + import sacrebleu + + comp_bleu = sacrebleu.compute_bleu + + fn_sig = inspect.getfullargspec(comp_bleu)[0] + if "smooth_method" in fn_sig: + 
smooth = {"smooth_method": "exp"} + else: + smooth = {"smooth": "exp"} + bleu = comp_bleu( + correct=meters["_bleu_counts"].sum, + total=meters["_bleu_totals"].sum, + sys_len=int(meters["_bleu_sys_len"].sum), + ref_len=int(meters["_bleu_ref_len"].sum), + **smooth, + ) + return round(bleu.score, 2) + + metrics.log_derived("bleu", compute_bleu) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.cfg.max_source_positions, self.cfg.max_target_positions) + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.src_dict + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary`.""" + return self.tgt_dict + + def _inference_with_bleu(self, generator, sample, model): + import sacrebleu + + def decode(toks, escape_unk=False): + s = self.tgt_dict.string( + toks.int().cpu(), + self.cfg.eval_bleu_remove_bpe, + # The default unknown string in fairseq is ``, but + # this is tokenized by sacrebleu as `< unk >`, inflating + # BLEU scores. Instead, we use a somewhat more verbose + # alternative that is unlikely to appear in the real + # reference, but doesn't get split into multiple tokens. + unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None) + hyps, refs = [], [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]["tokens"])) + refs.append( + decode( + utils.strip_pad(sample["target"][i], self.tgt_dict.pad()), + escape_unk=True, # don't count as matches to the hypo + ) + ) + if self.cfg.eval_bleu_print_samples: + logger.info("example hypothesis: " + hyps[0]) + logger.info("example reference: " + refs[0]) + if self.cfg.eval_tokenized_bleu: + return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none") + else: + return sacrebleu.corpus_bleu(hyps, [refs]) diff --git a/fairseq/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/fairseq/tasks/translation_from_pretrained_bart.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd7a5b29f0e34699b5d5ef7574bc39b8c6052c9 --- /dev/null +++ b/fairseq/fairseq/tasks/translation_from_pretrained_bart.py @@ -0,0 +1,132 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq import utils +from fairseq.data import LanguagePairDataset + +from . import register_task +from .translation import TranslationTask, load_langpair_dataset + + +@register_task("translation_from_pretrained_bart") +class TranslationFromPretrainedBARTTask(TranslationTask): + """ + Translate from source language to target language with a model initialized with a multilingual pretrain. + + Args: + src_dict (~fairseq.data.Dictionary): dictionary for the source language + tgt_dict (~fairseq.data.Dictionary): dictionary for the target language + + .. note:: + + The translation task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate` and :mod:`fairseq-interactive`. + + The translation task provides the following additional command-line + arguments: + + .. 
argparse:: + :ref: fairseq.tasks.translation_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + TranslationTask.add_args(parser) + parser.add_argument('--langs', type=str, metavar='LANG', + help='comma-separated list of monolingual languages, ' + 'for example, "en,de,fr". These should match the ' + 'langs from pretraining (and be in the same order). ' + 'You should always add all pretraining language idx ' + 'during finetuning.') + parser.add_argument('--prepend-bos', action='store_true', + help='prepend bos token to each sentence, which matches ' + 'mBART pretraining') + # fmt: on + + def __init__(self, args, src_dict, tgt_dict): + super().__init__(args, src_dict, tgt_dict) + self.langs = args.langs.split(",") + for d in [src_dict, tgt_dict]: + for l in self.langs: + d.add_symbol("[{}]".format(l)) + d.add_symbol("<mask>") + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + # infer langcode + src, tgt = self.args.source_lang, self.args.target_lang + + self.datasets[split] = load_langpair_dataset( + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.args.dataset_impl, + upsample_primary=self.args.upsample_primary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=getattr(self.args, "max_source_positions", 1024), + max_target_positions=getattr(self.args, "max_target_positions", 1024), + load_alignments=self.args.load_alignments, + prepend_bos=getattr(self.args, "prepend_bos", False), + append_source_id=True, + ) + + def build_generator(self, models, args, **unused): + if getattr(args, "score_reference", False): + from fairseq.sequence_scorer import SequenceScorer + + return SequenceScorer( + self.target_dictionary, + eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), + ) + else: + from fairseq.sequence_generator import SequenceGenerator + + return SequenceGenerator( + models, + self.target_dictionary, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang)) + source_tokens = [] + for s_t in src_tokens: + s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)]) + source_tokens.append(s_t) + dataset = LanguagePairDataset( + source_tokens, + src_lengths, + self.source_dictionary, + tgt_dict=self.target_dictionary, + constraints=constraints, + ) + return dataset diff --git a/fairseq/fairseq/tasks/translation_from_pretrained_xlm.py b/fairseq/fairseq/tasks/translation_from_pretrained_xlm.py new file mode 100644 index 0000000000000000000000000000000000000000..a05f2891524a8b23482e206c1742c3b816b77afb ---
/dev/null +++ b/fairseq/fairseq/tasks/translation_from_pretrained_xlm.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary +from fairseq.tasks.translation import TranslationConfig, TranslationTask + +from . import register_task + + +@dataclass +class TranslationFromPretrainedXLMConfig(TranslationConfig): + pass + + +@register_task( + "translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig +) +class TranslationFromPretrainedXLMTask(TranslationTask): + """ + Same as TranslationTask except use the MaskedLMDictionary class so that + we can load data that was binarized with the MaskedLMDictionary class. + + This task should be used for the entire training pipeline when we want to + train an NMT model from a pretrained XLM checkpoint: binarizing NMT data, + training NMT with the pretrained XLM checkpoint, and subsequent evaluation + of that trained model. + """ + + @classmethod + def load_dictionary(cls, filename): + """Load the masked LM dictionary from the filename + + Args: + filename (str): the filename + """ + return MaskedLMDictionary.load(filename) diff --git a/fairseq/fairseq/tasks/translation_lev.py b/fairseq/fairseq/tasks/translation_lev.py new file mode 100644 index 0000000000000000000000000000000000000000..b45fecd1f40ae43829ef43633a04dcbfd77a4136 --- /dev/null +++ b/fairseq/fairseq/tasks/translation_lev.py @@ -0,0 +1,195 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +import torch +from fairseq import utils +from fairseq.data import LanguagePairDataset +from fairseq.dataclass import ChoiceEnum +from fairseq.tasks import register_task +from fairseq.tasks.translation import ( + TranslationConfig, + TranslationTask, + load_langpair_dataset, +) +from fairseq.utils import new_arange + + +NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"]) + + +@dataclass +class TranslationLevenshteinConfig(TranslationConfig): + noise: NOISE_CHOICES = field( + default="random_delete", + metadata={"help": "type of noise"}, + ) + + +@register_task("translation_lev", dataclass=TranslationLevenshteinConfig) +class TranslationLevenshteinTask(TranslationTask): + """ + Translation (Sequence Generation) task for Levenshtein Transformer + See `"Levenshtein Transformer" `_. + """ + + cfg: TranslationLevenshteinConfig + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. 
+ + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + # infer langcode + src, tgt = self.cfg.source_lang, self.cfg.target_lang + + self.datasets[split] = load_langpair_dataset( + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.cfg.dataset_impl, + upsample_primary=self.cfg.upsample_primary, + left_pad_source=self.cfg.left_pad_source, + left_pad_target=self.cfg.left_pad_target, + max_source_positions=self.cfg.max_source_positions, + max_target_positions=self.cfg.max_target_positions, + prepend_bos=True, + ) + + def inject_noise(self, target_tokens): + def _random_delete(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + + max_len = target_tokens.size(1) + target_mask = target_tokens.eq(pad) + target_score = target_tokens.clone().float().uniform_() + target_score.masked_fill_( + target_tokens.eq(bos) | target_tokens.eq(eos), 0.0 + ) + target_score.masked_fill_(target_mask, 1) + target_score, target_rank = target_score.sort(1) + target_length = target_mask.size(1) - target_mask.float().sum( + 1, keepdim=True + ) + + # do not delete <bos> and <eos> (we assign 0 score for them) + target_cutoff = ( + 2 + + ( + (target_length - 2) + * target_score.new_zeros(target_score.size(0), 1).uniform_() + ).long() + ) + target_cutoff = target_score.sort(1)[1] >= target_cutoff + + prev_target_tokens = ( + target_tokens.gather(1, target_rank) + .masked_fill_(target_cutoff, pad) + .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1]) + ) + prev_target_tokens = prev_target_tokens[ + :, : prev_target_tokens.ne(pad).sum(1).max() + ] + + return prev_target_tokens + + def _random_mask(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + unk = self.tgt_dict.unk() + + target_masks = ( + target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos) + ) + target_score = target_tokens.clone().float().uniform_() + target_score.masked_fill_(~target_masks, 2.0) + target_length = target_masks.sum(1).float() + target_length = target_length * target_length.clone().uniform_() + target_length = target_length + 1 # make sure to mask at least one token.
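+ # below: rank positions by score and replace the target_length lowest-scored ones with unk; pad/bos/eos were scored 2.0 above, so they are never selected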
+ + _, target_rank = target_score.sort(1) + target_cutoff = new_arange(target_rank) < target_length[:, None].long() + prev_target_tokens = target_tokens.masked_fill( + target_cutoff.scatter(1, target_rank, target_cutoff), unk + ) + return prev_target_tokens + + def _full_mask(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + unk = self.tgt_dict.unk() + + target_mask = ( + target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad) + ) + return target_tokens.masked_fill(~target_mask, unk) + + if self.cfg.noise == "random_delete": + return _random_delete(target_tokens) + elif self.cfg.noise == "random_mask": + return _random_mask(target_tokens) + elif self.cfg.noise == "full_mask": + return _full_mask(target_tokens) + elif self.cfg.noise == "no_noise": + return target_tokens + else: + raise NotImplementedError + + def build_generator(self, models, args, **unused): + # add models input to match the API for SequenceGenerator + from fairseq.iterative_refinement_generator import IterativeRefinementGenerator + + return IterativeRefinementGenerator( + self.target_dictionary, + eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0), + max_iter=getattr(args, "iter_decode_max_iter", 10), + beam_size=getattr(args, "iter_decode_with_beam", 1), + reranking=getattr(args, "iter_decode_with_external_reranker", False), + decoding_format=getattr(args, "decoding_format", None), + adaptive=not getattr(args, "iter_decode_force_max_iter", False), + retain_history=getattr(args, "retain_iter_history", False), + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/ + raise NotImplementedError( + "Constrained decoding with the translation_lev task is not supported" + ) + + return LanguagePairDataset( + src_tokens, src_lengths, self.source_dictionary, append_bos=True + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + model.train() + sample["prev_target"] = self.inject_noise(sample["target"]) + loss, sample_size, logging_output = criterion(model, sample) + if ignore_grad: + loss *= 0 + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + sample["prev_target"] = self.inject_noise(sample["target"]) + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output diff --git a/fairseq/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/fairseq/tasks/translation_multi_simple_epoch.py new file mode 100644 index 0000000000000000000000000000000000000000..5db36a7c79ab291319339f5df15b70234154eda2 --- /dev/null +++ b/fairseq/fairseq/tasks/translation_multi_simple_epoch.py @@ -0,0 +1,441 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import datetime +import logging +import time + +import torch +from fairseq.data import ( + FairseqDataset, + LanguagePairDataset, + ListDataset, + data_utils, + iterators, +) +from fairseq.data.multilingual.multilingual_data_manager import ( + MultilingualDatasetManager, +) +from fairseq.data.multilingual.sampling_method import SamplingMethod +from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.utils import FileContentsAction + + +### +def get_time_gap(s, e): + return ( + datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s) + ).__str__() + + +### + + +logger = logging.getLogger(__name__) + + +@register_task("translation_multi_simple_epoch") +class TranslationMultiSimpleEpochTask(LegacyFairseqTask): + """ + Translate from one (source) language to another (target) language. + + Args: + langs (List[str]): a list of languages that are being supported + dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries + training (bool): whether the task should be configured for training or not + + .. note:: + + The translation task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate` and :mod:`fairseq-interactive`. + + The translation task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.translation_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', + help='inference source language') + parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', + help='inference target language') + parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', + help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr', + action=FileContentsAction) + parser.add_argument('--keep-inference-langtok', action='store_true', + help='keep language tokens in inference output (e.g. for analysis or debugging)') + + SamplingMethod.add_arguments(parser) + MultilingualDatasetManager.add_args(parser) + # fmt: on + + def __init__(self, args, langs, dicts, training): + super().__init__(args) + self.langs = langs + self.dicts = dicts + self.training = training + if training: + self.lang_pairs = args.lang_pairs + else: + self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)] + # eval_lang_pairs for multilingual translation is usually all of the + # lang_pairs. However for other multitask settings or when we want to + # optimize for certain languages we want to use a different subset. Thus + # the eval_lang_pairs class variable is provided for classes that extend + # this class. + self.eval_lang_pairs = self.lang_pairs + # model_lang_pairs will be used to build encoder-decoder model pairs in + # models.build_model(). 
This allows a multitask sub-class to + # build models other than the input lang_pairs + self.model_lang_pairs = self.lang_pairs + self.source_langs = [d.split("-")[0] for d in self.lang_pairs] + self.target_langs = [d.split("-")[1] for d in self.lang_pairs] + self.check_dicts(self.dicts, self.source_langs, self.target_langs) + + self.sampling_method = SamplingMethod.build_sampler(args, self) + self.data_manager = MultilingualDatasetManager.setup_data_manager( + args, self.lang_pairs, langs, dicts, self.sampling_method + ) + + def check_dicts(self, dicts, source_langs, target_langs): + if self.args.source_dict is not None or self.args.target_dict is not None: + # no need to check whether the source side and target side are sharing dictionaries + return + src_dict = dicts[source_langs[0]] + tgt_dict = dicts[target_langs[0]] + for src_lang in source_langs: + assert ( + src_dict == dicts[src_lang] + ), "Different dictionaries are specified for different source languages; " + "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages" + for tgt_lang in target_langs: + assert ( + tgt_dict == dicts[tgt_lang] + ), "Different dictionaries are specified for different target languages; " + "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages" + + @classmethod + def setup_task(cls, args, **kwargs): + langs, dicts, training = MultilingualDatasetManager.prepare( + cls.load_dictionary, args, **kwargs + ) + return cls(args, langs, dicts, training) + + def has_sharded_data(self, split): + return self.data_manager.has_sharded_data(split) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if split in self.datasets: + dataset = self.datasets[split] + if self.has_sharded_data(split): + if self.args.virtual_epoch_size is not None: + if dataset.load_next_shard: + shard_epoch = dataset.shard_epoch + else: + # no need to load next shard so skip loading + # also this avoids always loading from the beginning of the data + return + else: + shard_epoch = epoch + else: + # estimate the shard epoch from virtual data size and virtual epoch size + shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch) + logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}") + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + if split in self.datasets: + del self.datasets[split] + logger.info("old dataset deleted manually") + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + self.datasets[split] = self.data_manager.load_dataset( + split, + self.training, + epoch=epoch, + combine=combine, + shard_epoch=shard_epoch, + **kwargs, + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the multilingual_translation task is not supported" + ) + + src_data = ListDataset(src_tokens, src_lengths) + dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary) + src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"] + if self.args.lang_tok_replacing_bos_eos: + dataset = self.data_manager.alter_dataset_langtok( + dataset, + src_eos=self.source_dictionary.eos(), + src_lang=self.args.source_lang, + tgt_eos=self.target_dictionary.eos(), + tgt_lang=self.args.target_lang, + src_langtok_spec=src_langtok_spec, + tgt_langtok_spec=tgt_langtok_spec, + ) + else: + dataset.src
= self.data_manager.src_dataset_tranform_func( + self.args.source_lang, + self.args.target_lang, + dataset=dataset.src, + spec=src_langtok_spec, + ) + return dataset + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + ): + if not getattr(args, "keep_inference_langtok", False): + _, tgt_langtok_spec = self.args.langtoks["main"] + if tgt_langtok_spec: + tgt_lang_tok = self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok} + + return super().build_generator( + models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + + def build_model(self, args, from_checkpoint=False): + return super().build_model(args, from_checkpoint) + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + return loss, sample_size, logging_output + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + _, tgt_langtok_spec = self.args.langtoks["main"] + if not self.args.lang_tok_replacing_bos_eos: + if prefix_tokens is None and tgt_langtok_spec: + tgt_lang_tok = self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.size(0) + prefix_tokens = ( + torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens) + ) + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) + else: + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + bos_token=self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) + if tgt_langtok_spec + else self.target_dictionary.eos(), + ) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) + + @property + def source_dictionary(self): + return self.data_manager.get_source_dictionary(self.source_langs[0]) + + @property + def target_dictionary(self): + return self.data_manager.get_target_dictionary(self.target_langs[0]) + + def create_batch_sampler_func( + self, + max_positions, + ignore_invalid_inputs, + max_tokens, + max_sentences, + required_batch_size_multiple=1, + seed=1, + ): + def construct_batch_sampler(dataset, epoch): + splits = [ + s for s, _ in self.datasets.items() if self.datasets[s] == dataset + ] + split = splits[0] if len(splits) > 0 else None + # NEW implementation + if epoch is not None: + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + # get indices ordered by example size + start_time = time.time() + logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}") + + with data_utils.numpy_seed(seed): + indices = dataset.ordered_indices() + logger.info( + f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}" + ) + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + + # filter examples that are too large + if max_positions is not None: + my_time = time.time() + indices = self.filter_indices_by_size( + indices, dataset, max_positions, ignore_invalid_inputs + ) + logger.info( + f"[{split}] @batch_sampler 
filter_by_size time: {get_time_gap(my_time, time.time())}" + ) + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + + # create mini-batches with given size constraints + my_time = time.time() + batch_sampler = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + logger.info( + f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}" + ) + logger.info( + f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}" + ) + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + + return batch_sampler + + return construct_batch_sampler + + # we need to override get_batch_iterator because we want to reset the epoch iterator each time + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, + ): + """ + Get an iterator that yields batches of data from the given dataset. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 0). + data_buffer_size (int, optional): number of batches to + preload (default: 0). + disable_iterator_cache (bool, optional): don't cache the + EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) + (default: False). + grouped_shuffling (bool, optional): group batches with each groups + containing num_shards batches and shuffle groups. Reduces difference + between sequence lengths among workers for batches sorted by length. 
+ update_epoch_batch_itr (bool optional): if true then donot use the cached + batch iterator for the epoch + + Returns: + ~fairseq.iterators.EpochBatchIterator: a batched iterator over the + given dataset split + """ + # initialize the dataset with the correct starting epoch + assert isinstance(dataset, FairseqDataset) + if dataset in self.dataset_to_epoch_iter: + return self.dataset_to_epoch_iter[dataset] + if self.args.sampling_method == "RoundRobin": + batch_iter = super().get_batch_iterator( + dataset, + max_tokens=max_tokens, + max_sentences=max_sentences, + max_positions=max_positions, + ignore_invalid_inputs=ignore_invalid_inputs, + required_batch_size_multiple=required_batch_size_multiple, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + data_buffer_size=data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=skip_remainder_batch, + update_epoch_batch_itr=update_epoch_batch_itr, + ) + self.dataset_to_epoch_iter[dataset] = batch_iter + return batch_iter + + construct_batch_sampler = self.create_batch_sampler_func( + max_positions, + ignore_invalid_inputs, + max_tokens, + max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + seed=seed, + ) + + epoch_iter = iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=construct_batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + ) + return epoch_iter diff --git a/fairseq/tests/speech_recognition/test_collaters.py b/fairseq/tests/speech_recognition/test_collaters.py new file mode 100644 index 0000000000000000000000000000000000000000..6a5029a48faea2426d7a0277655a2c7c08c1d16c --- /dev/null +++ b/fairseq/tests/speech_recognition/test_collaters.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import numpy as np +import torch +from examples.speech_recognition.data.collaters import Seq2SeqCollater + + +class TestSeq2SeqCollator(unittest.TestCase): + def test_collate(self): + + eos_idx = 1 + pad_idx = 0 + collater = Seq2SeqCollater( + feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx + ) + + # 2 frames in the first sample and 3 frames in the second one + frames1 = np.array([[7, 8], [9, 10]]) + frames2 = np.array([[1, 2], [3, 4], [5, 6]]) + target1 = np.array([4, 2, 3, eos_idx]) + target2 = np.array([3, 2, eos_idx]) + sample1 = {"id": 0, "data": [frames1, target1]} + sample2 = {"id": 1, "data": [frames2, target2]} + batch = collater.collate([sample1, sample2]) + + # collate sort inputs by frame's length before creating the batch + self.assertTensorEqual(batch["id"], torch.tensor([1, 0])) + self.assertEqual(batch["ntokens"], 7) + self.assertTensorEqual( + batch["net_input"]["src_tokens"], + torch.tensor( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]] + ), + ) + self.assertTensorEqual( + batch["net_input"]["prev_output_tokens"], + torch.tensor([[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]), + ) + self.assertTensorEqual(batch["net_input"]["src_lengths"], torch.tensor([3, 2])) + self.assertTensorEqual( + batch["target"], + torch.tensor([[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]), + ) + self.assertEqual(batch["nsentences"], 2) + + def assertTensorEqual(self, t1, t2): + self.assertEqual(t1.size(), t2.size(), "size mismatch") + self.assertEqual(t1.ne(t2).long().sum(), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/fairseq/tests/speech_recognition/test_data_utils.py b/fairseq/tests/speech_recognition/test_data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a72e0b66948da1349d87eafdef4c4004dd535c96 --- /dev/null +++ b/fairseq/tests/speech_recognition/test_data_utils.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import unittest + +import torch +from examples.speech_recognition.data import data_utils + + +class DataUtilsTest(unittest.TestCase): + def test_normalization(self): + sample_len1 = torch.tensor( + [ + [ + -0.7661, + -1.3889, + -2.0972, + -0.9134, + -0.7071, + -0.9765, + -0.8700, + -0.8283, + 0.7512, + 1.3211, + 2.1532, + 2.1174, + 1.2800, + 1.2633, + 1.6147, + 1.6322, + 2.0723, + 3.1522, + 3.2852, + 2.2309, + 2.5569, + 2.2183, + 2.2862, + 1.5886, + 0.8773, + 0.8725, + 1.2662, + 0.9899, + 1.1069, + 1.3926, + 1.2795, + 1.1199, + 1.1477, + 1.2687, + 1.3843, + 1.1903, + 0.8355, + 1.1367, + 1.2639, + 1.4707, + ] + ] + ) + out = data_utils.apply_mv_norm(sample_len1) + assert not torch.isnan(out).any() + assert (out == sample_len1).all() diff --git a/fairseq/tests/tasks/test_denoising.py b/fairseq/tests/tasks/test_denoising.py new file mode 100644 index 0000000000000000000000000000000000000000..5c221683523762dc0ce397505fd7095d701a6b46 --- /dev/null +++ b/fairseq/tests/tasks/test_denoising.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import unittest +from tempfile import TemporaryDirectory + +from fairseq import options +from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.tasks.denoising import DenoisingTask +from tests.utils import build_vocab, make_data + + +class TestDenoising(unittest.TestCase): + def test_denoising(self): + with TemporaryDirectory() as dirname: + + # prep input file + raw_file = os.path.join(dirname, "raw") + data = make_data(out_file=raw_file) + vocab = build_vocab(data) + + # binarize + binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False) + split = "train" + bin_file = os.path.join(dirname, split) + dataset_impl = "mmap" + FileBinarizer.multiprocess_dataset( + input_file=raw_file, + binarizer=binarizer, + dataset_impl=dataset_impl, + vocab_size=len(vocab), + output_prefix=bin_file, + ) + + # setup task + train_args = options.parse_args_and_arch( + options.get_training_parser(), + [ + "--task", + "denoising", + "--arch", + "bart_base", + "--seed", + "42", + "--mask-length", + "word", + "--permute-sentences", + "1", + "--rotate", + "0", + "--replace-length", + "-1", + "--mask", + "0.2", + dirname, + ], + ) + cfg = convert_namespace_to_omegaconf(train_args) + task = DenoisingTask(cfg.task, binarizer.dict) + + # load datasets + original_dataset = task._load_dataset_split(bin_file, 1, False) + task.load_dataset(split) + masked_dataset = task.dataset(split) + + iterator = task.get_batch_iterator( + dataset=masked_dataset, + max_tokens=65_536, + max_positions=4_096, + ).next_epoch_itr(shuffle=False) + mask_index = task.source_dictionary.index("<mask>") + for batch in iterator: + for sample in range(len(batch)): + net_input = batch["net_input"] + masked_src_tokens = net_input["src_tokens"][sample] + masked_src_length = net_input["src_lengths"][sample] + masked_tgt_tokens = batch["target"][sample] + + sample_id = batch["id"][sample] + original_tokens = original_dataset[sample_id] + original_tokens = original_tokens.masked_select( + masked_src_tokens[:masked_src_length] == mask_index + ) + masked_tokens = masked_tgt_tokens.masked_select( + masked_src_tokens == mask_index + ) + + assert masked_tokens.equal(original_tokens) + + +if __name__ == "__main__": + unittest.main() diff --git a/ssl-aasist-config/config.json b/ssl-aasist-config/config.json new file mode 100644 index 0000000000000000000000000000000000000000..81b4a34a26382182d55f31562ca895ff42dc66df --- /dev/null +++ b/ssl-aasist-config/config.json @@ -0,0 +1,39 @@ +{ + "filts": [ + 128, + [ + 1, + 32 + ], + [ + 32, + 32 + ], + [ + 32, + 64 + ], + [ + 64, + 64 + ] + ], + "gat_dims": [ + 64, + 32 + ], + "model_type": "ssl-aasist", + "pool_ratios": [ + 0.5, + 0.5, + 0.5, + 0.5 + ], + "temperatures": [ + 2.0, + 2.0, + 100.0, + 100.0 + ], + "transformers_version": "4.48.3" +}
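The ssl-aasist config above is plain JSON, so it can be inspected without any fairseq machinery. A minimal sketch of loading and sanity-checking it, assuming the file sits at ssl-aasist-config/config.json relative to the repository root (only the key names from the file above are relied on; the checks themselves are illustrative):

import json

# load the ssl-aasist hyperparameters added in this change
with open("ssl-aasist-config/config.json") as f:
    cfg = json.load(f)

assert cfg["model_type"] == "ssl-aasist"
# the config lists four pool ratios and four matching temperatures
assert len(cfg["pool_ratios"]) == len(cfg["temperatures"]) == 4
print(cfg["filts"], cfg["gat_dims"])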