ssl-aasist
Commit eaa8a4e · verified · 1 Parent(s): 211c22d · ash56 committed

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. fairseq/examples/hubert/tests/sample.xlarge.L30.npy +3 -0
  2. fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md +47 -0
  3. fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md +47 -0
  4. fairseq/examples/textless_nlp/gslm/README.md +21 -0
  5. fairseq/examples/textless_nlp/gslm/metrics/README.md +10 -0
  6. fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py +107 -0
  7. fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md +87 -0
  8. fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py +201 -0
  9. fairseq/examples/textless_nlp/gslm/speech2unit/README.md +68 -0
  10. fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py +91 -0
  11. fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py +141 -0
  12. fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py +20 -0
  13. fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py +204 -0
  14. fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py +70 -0
  15. fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py +34 -0
  16. fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py +127 -0
  17. fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py +56 -0
  18. fairseq/examples/textless_nlp/gslm/tools/README.md +25 -0
  19. fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py +132 -0
  20. fairseq/examples/textless_nlp/gslm/ulm/README.md +72 -0
  21. fairseq/examples/textless_nlp/gslm/ulm/sample.py +174 -0
  22. fairseq/examples/textless_nlp/gslm/unit2speech/README.md +40 -0
  23. fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py +56 -0
  24. fairseq/examples/textless_nlp/gslm/unit2speech/glow.py +312 -0
  25. fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py +27 -0
  26. fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py +105 -0
  27. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py +0 -0
  28. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py +93 -0
  29. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py +90 -0
  30. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py +65 -0
  31. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py +103 -0
  32. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py +669 -0
  33. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py +71 -0
  34. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py +141 -0
  35. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py +18 -0
  36. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py +107 -0
  37. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py +171 -0
  38. fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py +40 -0
  39. fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py +54 -0
  40. fairseq/examples/textless_nlp/gslm/unit2speech/utils.py +55 -0
  41. fairseq/examples/textless_nlp/pgslm/README.md +318 -0
  42. fairseq/examples/textless_nlp/pgslm/data_utils.py +107 -0
  43. fairseq/examples/textless_nlp/pgslm/eval/__init__.py +4 -0
  44. fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py +730 -0
  45. fairseq/examples/textless_nlp/pgslm/generate_waveform.py +120 -0
  46. fairseq/examples/textless_nlp/pgslm/inference_dataset.py +103 -0
  47. fairseq/examples/textless_nlp/pgslm/naive_decoder.py +40 -0
  48. fairseq/examples/textless_nlp/pgslm/prepare_dataset.py +143 -0
  49. fairseq/examples/textless_nlp/pgslm/preprocess_f0.py +65 -0
  50. fairseq/examples/textless_nlp/pgslm/quantize_f0.py +94 -0
fairseq/examples/hubert/tests/sample.xlarge.L30.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbe9f0929ecd4c58786be55215d83f85787ff0d81196bc2e73414f82a8939806
+ size 3051712
fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md ADDED
@@ -0,0 +1,47 @@
+ # Dialogue Speech-to-Unit Encoder for dGSLM: The Fisher HuBERT model
+ For the speech2unit encoder, we train a [HuBERT model](https://arxiv.org/pdf/2106.07447.pdf) on the [Fisher dataset](http://www.lrec-conf.org/proceedings/lrec2004/pdf/767.pdf) for 3 iterations (see [our paper](https://arxiv.org/pdf/2203.16502.pdf) for more details) and train a k-means model with 500 units on the layer 12 features of the HuBERT model.
+
+ ## Model checkpoints
+ The pre-trained HuBERT and k-means model checkpoints can be found here:
+
+ | Fisher HuBERT model | k-means model |
+ |---------------------|---------------|
+ |[download](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hubert/hubert_fisher.pt)|[download](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hubert/hubert_fisher_km_500.bin)|
+
+
+ ## Encode audio to discrete units
+ Below is an example command to encode a stereo dataset to discrete units using the pre-trained model checkpoints:
+ ```bash
+ for CHANNEL_ID in 1 2; do
+     python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \
+         --feature_type hubert \
+         --kmeans_model_path path/to/hubert_fisher_km_500.bin \
+         --acoustic_model_path path/to/hubert_fisher.pt \
+         --layer 12 \
+         --manifest_path $MANIFEST_FILE \
+         --out_quantized_file_path ${OUTPUT_FILE}-channel${CHANNEL_ID} \
+         --extension $EXTENSION \
+         --channel_id $CHANNEL_ID
+ done
+ ```
+ where MANIFEST_FILE is the output of the [wav2vec manifest script](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/wav2vec_manifest.py), which can be obtained with the following command:
+ ```
+ python examples/wav2vec/wav2vec_manifest.py --valid-percent=0.0 $AUDIO_DIR --dest=$OUTPUT_DIR --ext=$EXTENSION
+ ```
+
+ Alternatively, you can encode an audio file interactively in Python with the HubertTokenizer class:
+ ```python
+ # Load the HuBERT tokenizer
+ from examples.textless_nlp.dgslm.dgslm_utils import HubertTokenizer
+ encoder = HubertTokenizer(
+     hubert_path = "/path/to/hubert_ckpt.pt",
+     hubert_layer = 12,
+     km_path = "path/to/km.bin"
+ )
+
+ # Encode the audio to units
+ path = "/path/to/stereo/audio.wav"
+ codes = encoder.wav2codes(path)
+ # > ['7 376 376 133 178 486 486 486 486 486 486 486 486 2 486',
+ # >  '7 499 415 177 7 7 7 7 7 7 136 136 289 289 408']
+ ```
fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md ADDED
@@ -0,0 +1,47 @@
+ # Dialogue Unit-to-Speech Decoder for dGSLM
+ For the unit2speech decoder, we train a [discrete unit-based HiFi-GAN vocoder](https://arxiv.org/pdf/2104.00355.pdf) on the [Fisher dataset](http://www.lrec-conf.org/proceedings/lrec2004/pdf/767.pdf).
+
+ ## Model checkpoint
+ The pre-trained model checkpoint can be found here:
+
+ | HiFi-GAN vocoder based on HuBERT Fisher Units |
+ |-----------------------------------------------|
+ |[model checkpoint](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hifigan/hifigan_vocoder) - [config](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hifigan/config.json) |
+
+ ## Decode discrete units to audio
+ To generate a waveform from discrete units, use the script `generate_stereo_waveform.py`:
+ ```bash
+ python examples/textless_nlp/dgslm/vocoder_hifigan/generate_stereo_waveform.py \
+     --in-file $INPUT_CODE_FILE \
+     --vocoder $VOCODER_PATH \
+     --vocoder-cfg $VOCODER_CONFIG \
+     --results-path $OUTPUT_DIR
+ ```
+ where INPUT_CODE_FILE is expected to have the following format:
+ ```
+ {'audio': 'file_1', 'unitA': '8 8 ... 352 352', 'unitB': '217 8 ... 8 8'}
+ {'audio': 'file_2', 'unitA': '5 5 ... 65 65', 'unitB': '6 35 ... 8 9'}
+ ...
+ ```
+
+ You can also use the HifiganVocoder class to generate waveforms from the codes interactively:
+ ```python
+ # Load the HiFi-GAN vocoder
+ from examples.textless_nlp.dgslm.dgslm_utils import HifiganVocoder
+ decoder = HifiganVocoder(
+     vocoder_path = "/path/to/hifigan_vocoder",
+     vocoder_cfg_path = "/path/to/config.json",
+ )
+
+ # Decode the units to waveform
+ codes = [
+     '7 376 376 133 178 486 486 486 486 486 486 486 486 2 486',
+     '7 499 415 177 7 7 7 7 7 7 136 136 289 289 408',
+ ]
+ wav = decoder.codes2wav(codes)
+ # > array of shape (2, 4800)
+
+ # Play the waveform
+ import IPython.display as ipd
+ ipd.Audio(wav, rate=16_000)
+ ```
fairseq/examples/textless_nlp/gslm/README.md ADDED
@@ -0,0 +1,21 @@
+ # Generative Spoken Language Modeling
+
+ * [Paper](https://arxiv.org/abs/2102.01192)
+ * [Demo](https://speechbot.github.io/gslm/index.html)
+
+ We build and evaluate generative speech2speech systems using [Log Mel Filterbank](https://pytorch.org/audio/stable/compliance.kaldi.html#fbank), [Modified CPC](https://github.com/facebookresearch/CPC_audio), [HuBERT Base](https://github.com/pytorch/fairseq/tree/main/examples/hubert) and [Wav2Vec 2.0 Large](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec). Our system is composed of three components: *speech2unit*, *ulm* and *unit2speech*. We describe the models and usage of these components in their respective sub-directories; see the links below.
+
+ ## Speech to Unit Model (speech2unit)
+ The speech to unit model is used for quantizing raw speech into learned discrete speech units. [More details](speech2unit)
+
+ ## Unit Language Model (ulm)
+ The unit language model is a generative language model trained on discrete speech units. [More details](ulm)
+
+ ## Unit to Speech Model (unit2speech)
+ The unit to speech model is used for synthesizing speech from discrete speech units. [More details](unit2speech)
+
+ ## Metrics
+ We show how to compute ASR-based metrics as well as the zero-shot metrics proposed in our paper [here](metrics).
+
+ ## Tools
+ We share two tools: one to resynthesize a given spoken utterance, and one to generate novel spoken language given a spoken prompt. [More details](tools)
fairseq/examples/textless_nlp/gslm/metrics/README.md ADDED
@@ -0,0 +1,10 @@
+ # GSLM Metrics
+
+ ## ASR Metrics
+ The suite of metrics here uses an ASR model to transcribe the synthesized speech into text, and then uses text-based metrics. We also use word error rate from ASR transcription itself as one of the metrics. [More details](asr_metrics)
+
+ ## ABX Metrics
+ We use [ABX](https://www.semanticscholar.org/paper/ABX-Discriminability-Measures-and-Applications-Schatz/13d3537228f728c1063cc83743cb118bba3367a0) to evaluate how well-separated phonetic categories are with quantized representations. [More details](abx_metrics)
+
+ ## sWUGGY and sBLIMP
+ We refer to [ZeroSpeech challenge](https://www.zerospeech.com/2021/track_s.html#scoring-based-metrics) for details on the sWUGGY and sBLIMP metrics.
fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+ import os
+
+ import joblib
+ import numpy as np
+
+ from examples.textless_nlp.gslm.speech2unit.clustering.utils import get_audio_files
+ from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_features
+
+ def get_logger():
+     log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+     logging.basicConfig(format=log_format, level=logging.INFO)
+     logger = logging.getLogger(__name__)
+     return logger
+
+ def get_parser():
+     parser = argparse.ArgumentParser(
+         description="Quantize using K-means clustering over acoustic features."
+     )
+     parser.add_argument(
+         "--feature_type",
+         type=str,
+         choices=["logmel", "hubert", "w2v2", "cpc"],
+         default=None,
+         required=True,
+         help="Acoustic feature type",
+     )
+     parser.add_argument(
+         "--kmeans_model_path",
+         type=str,
+         required=True,
+         help="K-means model file path to use for inference",
+     )
+     parser.add_argument(
+         "--manifest_path",
+         type=str,
+         default=None,
+         help="Manifest file containing the root dir and file names",
+     )
+     parser.add_argument(
+         "--checkpoint_path",
+         type=str,
+         help="Pretrained model checkpoint",
+     )
+     parser.add_argument(
+         "--layer",
+         type=int,
+         help="The layer of the pretrained model to extract features from",
+         default=-1,
+     )
+     parser.add_argument(
+         "--out_dir_path",
+         required=True,
+         type=str,
+         help="Output directory for the quantized features.",
+     )
+     parser.add_argument(
+         "--extension", type=str, default=".flac", help="Audio file extension"
+     )
+     return parser
+
+
+ def one_hot(feat, n_clusters):
+     return np.eye(n_clusters)[feat]
+
+ def main(args, logger):
+     # Feature extraction
+     logger.info(f"Extracting {args.feature_type} acoustic features...")
+     features_batch = get_features(
+         feature_type=args.feature_type,
+         checkpoint_path=args.checkpoint_path,
+         layer=args.layer,
+         manifest_path=args.manifest_path,
+         sample_pct=1.0,
+         flatten=False,
+     )
+     logger.info(f"Features extracted for {len(features_batch)} utterances.\n")
+     logger.info(f"Dimensionality of representation = {features_batch[0].shape[1]}")
+
+     logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
+     kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
+     kmeans_model.verbose = False
+
+     _, fnames, _ = get_audio_files(args.manifest_path)
+
+     os.makedirs(args.out_dir_path, exist_ok=True)
+     logger.info(f"Writing quantized features to {args.out_dir_path}")
+     for i, feats in enumerate(features_batch):
+         pred = kmeans_model.predict(feats)
+         emb = one_hot(pred, kmeans_model.n_clusters)
+         base_fname = os.path.basename(fnames[i]).rstrip(args.extension)
+         output_path = os.path.join(args.out_dir_path, f"{base_fname}.npy")
+         with open(output_path, "wb") as f:
+             np.save(f, emb)
+
+ if __name__ == "__main__":
+     parser = get_parser()
+     args = parser.parse_args()
+     logger = get_logger()
+     logger.info(args)
+     main(args, logger)
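
As a quick sanity check (not part of the released files), the script above writes one `.npy` per utterance containing a one-hot matrix of shape `(num_frames, n_clusters)`. A minimal sketch of that round-trip, using a toy scikit-learn K-means model in place of the released checkpoints; the feature values and the file name `sample.npy` are made up for illustration:

```python
import numpy as np
from sklearn.cluster import KMeans

# Toy "acoustic features": 20 frames of 8-dimensional vectors.
feats = np.random.RandomState(0).randn(20, 8)

# Stand-in for the released k-means checkpoint loaded with joblib above.
kmeans_model = KMeans(n_clusters=5, n_init=10, random_state=0).fit(feats)

pred = kmeans_model.predict(feats)           # (20,) frame-level cluster ids
emb = np.eye(kmeans_model.n_clusters)[pred]  # (20, 5) one-hot, as in one_hot()

np.save("sample.npy", emb)                   # what dump_abx_feats.py writes per utterance

# Recover the discrete unit sequence from a dumped file.
units = np.load("sample.npy").argmax(axis=1)
assert (units == pred).all()
```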
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md ADDED
@@ -0,0 +1,87 @@
+ # ASR-based evaluation
+
+ Overall, the life cycle of the ASR-based evaluation for a ULM contains the following steps:
+ 1. Training a ULM and sampling from it [[description]](./../../ulm)
+ 2. Running UTS on the sampled unit sequences [[description]](./../../unit2speech)
+ 3. Pre-processing for the ASR (down-sampling to 16 kHz, aligning the length of the generated audio with the ground-truth utterances)
+ 4. Running ASR
+ 5. Calculation of the post-ASR evaluation metrics
+
+ Here we assume that you have already gone through the first two steps and focus on the rest.
+
+ ## Preprocessing
+ ### Down-sampling to 16 kHz
+ The bulk conversion can be done by running
+ ```bash
+ python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py $UTS_OUTPUT $UTS_OUTPUT_DOWNSAMPLE
+ ```
+ where `$UTS_OUTPUT` specifies the directory with the generated audio and `$UTS_OUTPUT_DOWNSAMPLE` is the directory where the downsampled audio will be saved.
+
+ ### Matching by length
+ This step is somewhat optional. However, if you want to compare the fluency and diversity of a generated speech utterance to that of the ground-truth speech with the same prefix, it is a good idea to force them to be of the same length.
+ ```bash
+ python $FAIRSEQ_ROOT/examples/textless_nlp/asr_metrics/cut_as.py \
+     --samples_dir=$UTS_OUTPUT_DOWNSAMPLE --out_dir=$UTS_OUTPUT_DOWNSAMPLE_CUT \
+     --prompts_description=data/ground_truth_continuation_dev.json
+ ```
+
+ Here `ground_truth_continuation_dev.json` is a json file with ground-truth text from LibriSpeech dev-clean, associated with some meta-data (assuming the evaluation is done on dev-clean). This file can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_dev.json). A similar file for test-clean is [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_test.json). These files are used for the evaluation and contain texts for audio sequences that are at least 6s long.
+
+ ## Running ASR
+ We use a pre-trained wav2vec model to run the ASR step. We first need to prepare manifest files which, roughly, tell the ASR system which files we want to transcribe. You can find more details and download the `960h_scratch.pt` checkpoint
+ [[here]](https://github.com/pytorch/fairseq/blob/main/examples/wav2vec/README.md). To run ASR, you will also need to
+ install KenLM, the Flashlight decoder, and download the KenLM 4-gram English language model.
+
+ ```bash
+ python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py \
+     $UTS_OUTPUT_DOWNSAMPLE_CUT --valid-percent 0.0 --dest $MANIFEST_DIR --ext wav
+ ```
+ where `$UTS_OUTPUT_DOWNSAMPLE_CUT` specifies the directory with the preprocessed UTS outputs and `$MANIFEST_DIR` is the output directory.
+
+ We will be running an out-of-the-box evaluation script which requires ground-truth transcripts to measure quality metrics. We are only
+ interested in the transcripts (and we have no ground truth for what our ULM generated anyway), hence we will just generate
+ some dummy transcripts instead:
+ ```bash
+ cp $FAIRSEQ_ROOT/examples/textless_nlp/gslm/asr_metrics/misc/dict.ltr.txt $MANIFEST_DIR
+ python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/asr_metrics/misc/dummy_asr_data.py --tsv=$MANIFEST_DIR/train.tsv \
+     --output-dir=$MANIFEST_DIR
+ ```
+
+ Now we are ready to run ASR:
+ ```
+ mkdir -p asr
+ python $FAIRSEQ_ROOT/examples/speech_recognition/infer.py \
+     $MANIFEST_DIR \
+     --task audio_pretraining --nbest 1 --path 960h_scratch.pt \
+     --gen-subset=train --results-path $PATH_TO_ASR_OUTPUT \
+     --w2l-decoder kenlm --lm-model 4-gram.bin \
+     --lexicon librispeech/lexicon_ltr.lst --word-score -1 \
+     --sil-weight 0 --lm-weight 2 --criterion ctc --labels ltr --max-tokens 300000 --remove-bpe letter
+ ```
+ where `lexicon_ltr.lst` is the LibriSpeech lexicon (it can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/lexicon_ltr.lst)) and `$PATH_TO_ASR_OUTPUT` is the output directory.
+
+ ## Evaluation metrics
+ We run evaluation on the 1,000 shortest sequences that are at least 6s long. To filter those from the ASR transcript, we additionally provide each metric script with the paths to the manifest and `ground_truth_continuation_*` files.
+
+ ### Perplexity (PPX)
+ To get a PPX metric estimate on an ASR transcript, you need to run the following command:
+ ```bash
+ python ppx.py $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \
+     --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
+ ```
+ where `--cut-tail` tells the script to ignore the last token on each line (ASR puts the sequence ID there).
+
+ ### Self- and Auto-BLEU
+ ```bash
+ python self_bleu.py $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \
+     --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
+ ```
+
+ ### Continuation-BLEU
+ ```bash
+ python continuation_eval.py --asr-transcript $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt \
+     --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
+ ```
+
+ ### AUC
+ Based on the metrics calculated above, we can estimate the AUC of the perplexity/diversity trade-off. We provide an illustration in a [Colab notebook](https://colab.research.google.com/drive/1pVPfOVax_PU3MkYdHRSsa-SI8GBUldNt?usp=sharing).
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py ADDED
@@ -0,0 +1,201 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import nltk
+ from misc.bleu_utils import sentence_bleu
+ import warnings
+
+
+ def get_target_sequences(manifest, ground_truth, to_take=1000):
+     import json
+     import pathlib
+
+     with open(ground_truth, 'r') as fin:
+         original_continuations = json.loads(fin.read())
+
+     sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
+     assert all(float(v) >= 6.0 for (_, v) in sequence2length)  # 6 seconds
+
+     sequence2length.sort(key=lambda x: x[1])
+     to_take_sequences = set(v[0] for v in sequence2length[:to_take])
+     to_take_ids = []
+
+     with open(manifest, 'r') as f:
+         f.readline()
+
+         for i, line in enumerate(f.readlines()):
+             seq_id = line.split()[0]
+             seq_id = pathlib.Path(seq_id).name.split('__')[0]
+
+             if seq_id in to_take_sequences:
+                 to_take_ids.append(i)
+
+     print(f'Took {len(to_take_ids)} ids')
+     return set(to_take_ids)
+
+
+ def get_args():
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--asr-transcript', type=str,
+                         help='Path to the transcript file.')
+
+     parser.add_argument('--manifest', required=True)
+     parser.add_argument('--prompts-description', required=True)
+
+     parser.add_argument('--cut-id', action='store_true',
+                         help='Whether to cut the first token (typically a sequence id)')
+     parser.add_argument('--cut-tail', action='store_true',
+                         help='Whether to cut the last token (typically a speaker id)')
+     parser.add_argument('--debug', action='store_true')
+
+     args = parser.parse_args()
+
+     return args
+
+
+ def get_self_bleu(utterances, averaging_mode, weights):
+     self_bleu = []
+
+     for i in range(len(utterances)):
+         hypo = utterances[i]
+         rest = utterances[:i] + utterances[i+1:]
+
+         self_bleu.append(sentence_bleu(rest, hypo, weights,
+                                        no_length_penalty=True, averaging_mode=averaging_mode))
+
+     return self_bleu
+
+
+ def get_self_bleu2_arithmetic(utterances):
+     weights = (0.5, 0.5)  # equal weight for unigrams and bigrams
+     return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)
+
+
+ def get_self_bleu2_geometric(utterances):
+     weights = (0.5, 0.5)
+     return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)
+
+
+ def get_auto_bleu2_arithmetic(utterances):
+     weights = (0.5, 0.5)
+     return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]
+
+
+ def get_auto_bleu2_geometric(utterances):
+     weights = (0.5, 0.5)
+     return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]
+
+
+ def get_auto_bleu3_geometric(utterances):
+     weights = (1./3, 1./3, 1./3)
+     return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]
+
+
+ def get_auto_bleu3_arithmetic(utterances):
+     weights = (1./3, 1./3, 1./3)
+     return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]
+
+
+ def get_self_bleu3_arithmetic(utterances):
+     weights = (1./3, 1./3, 1./3)
+     return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)
+
+
+ def get_self_bleu3_geometric(utterances):
+     weights = (1./3, 1./3, 1./3)
+     return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)
+
+
+ def auto_bleu(sentence, weights, mean_mode='arithmetic'):
+     if len(sentence) <= 1:
+         return 0
+
+     N = len(weights)
+
+     bleu_n = np.zeros([N])
+     for n in range(N):
+         targ_ngrams = list(nltk.ngrams(sentence, n+1))
+         for p in range(len(targ_ngrams)):
+             left = sentence[:p]
+             right = sentence[(p+n+1):]
+             rest_ngrams = list(nltk.ngrams(left, n+1)) + \
+                 list(nltk.ngrams(right, n+1))
+             # compute the nb of matching ngrams
+             bleu_n[n] += targ_ngrams[p] in rest_ngrams
+         bleu_n[n] /= len(targ_ngrams)  # average them to get a proportion
+
+     weights = np.array(weights)
+     if mean_mode == 'arithmetic':
+         return (bleu_n * weights).sum()
+     elif mean_mode == 'geometric':
+         return (bleu_n ** weights).prod()
+     else:
+         raise ValueError(f'Unknown aggregation mode {mean_mode}')
+
+
+ def main():
+     from multiprocessing import Pool
+
+     args = get_args()
+     target_ids = get_target_sequences(args.manifest, args.prompts_description)
+
+     with open(args.asr_transcript, 'r') as fin:
+         lines = fin.readlines()
+
+     terms = [x.strip().split() for x in lines]
+     filtered = []
+     for term in terms:
+         line_id = int(term[-1].split('-')[1][:-1])
+         if line_id in target_ids:
+             filtered.append(term)
+     terms = filtered
+
+     if args.cut_id:
+         terms = [x[1:] for x in terms]
+     if args.cut_tail:
+         terms = [x[:-1] for x in terms]
+
+     if args.debug:
+         terms = terms[:10]
+
+     tasks = [
+         ('Self-BLEU2-arithmetic', get_self_bleu2_arithmetic),
+         ('Self-BLEU2-geometric', get_self_bleu2_geometric),
+         ('Auto-BLEU2-arithmetic', get_auto_bleu2_arithmetic),
+         ('Auto-BLEU2-geometric', get_auto_bleu2_geometric),
+
+         ('Self-BLEU3-arithmetic', get_self_bleu3_arithmetic),
+         ('Self-BLEU3-geometric', get_self_bleu3_geometric),
+         ('Auto-BLEU3-arithmetic', get_auto_bleu3_arithmetic),
+         ('Auto-BLEU3-geometric', get_auto_bleu3_geometric),
+     ]
+
+     n_processes = min(16, len(tasks))
+     with Pool(n_processes) as pool:
+         metrics = pool.map(run_f, [(t[1], terms) for t in tasks])
+
+     for (metric_name, _), metric in zip(tasks, metrics):
+         metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))
+
+         metric, sem = [
+             round(100 * x, 2) for x in [metric, sem]
+         ]
+
+         print(f'{metric_name} {metric} +- {sem}')
+
+
+ def run_f(task_params):
+     f, terms = task_params
+     return f(terms)
+
+
+ if __name__ == '__main__':
+     # NLTK produces warnings
+     warnings.filterwarnings("ignore")
+
+     main()
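
For intuition, here is a minimal worked example (not part of the released files) of the Auto-BLEU-2 statistic computed by `auto_bleu` above: for each n-gram position we check whether the same n-gram occurs elsewhere in the same sentence, average those matches per n, and then combine the per-n proportions with the given weights. The toy transcript is made up:

```python
import numpy as np
import nltk

sentence = "the cat sat on the mat".split()
weights = (0.5, 0.5)  # unigrams and bigrams, as in get_auto_bleu2_arithmetic

bleu_n = np.zeros(len(weights))
for n in range(len(weights)):
    targ_ngrams = list(nltk.ngrams(sentence, n + 1))
    for p in range(len(targ_ngrams)):
        # n-grams of the sentence with position p (and its n-gram span) removed
        rest = list(nltk.ngrams(sentence[:p], n + 1)) + list(nltk.ngrams(sentence[p + n + 1:], n + 1))
        bleu_n[n] += targ_ngrams[p] in rest
    bleu_n[n] /= len(targ_ngrams)

print(bleu_n)                               # [0.333..., 0.0]: only "the" repeats, no bigram does
print((bleu_n * np.array(weights)).sum())   # arithmetic Auto-BLEU-2 of the toy sentence
```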
fairseq/examples/textless_nlp/gslm/speech2unit/README.md ADDED
@@ -0,0 +1,68 @@
+ # Speech to Unit Model (speech2unit)
+
+ ## Acoustic Model
+ For quantizing speech, we learn a K-means clustering over acoustic representations, for which we use either Log-Mel Filterbank features or pretrained acoustic representation models. To use the pretrained models, please download them from their respective locations linked below.
+ * [Modified CPC](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/cpc_big_ll6kh_top_ctc.pt)
+ * [HuBERT-Base](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt)
+ * [Wav2Vec 2.0-Base](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt)
+
+ ## Quantization Model
+ You can download a pretrained quantization (K-means) model from the list below.
+
+ K-Means Model | Download Link
+ |-|-
+ Log Mel Filterbank + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km50/km.bin)
+ Log Mel Filterbank + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km100/km.bin)
+ Log Mel Filterbank + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km200/km.bin)
+ Modified CPC + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km50/km.bin)
+ Modified CPC + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km100/km.bin)
+ Modified CPC + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km200/km.bin)
+ HuBERT Base + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km50/km.bin)
+ HuBERT Base + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km100/km.bin)
+ HuBERT Base + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km200/km.bin)
+ wav2vec 2.0 Large + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km50/km.bin)
+ wav2vec 2.0 Large + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km100/km.bin)
+ wav2vec 2.0 Large + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km200/km.bin)
+
+ ### Quantization
+ For quantizing speech with a given acoustic representation, please follow the steps below.
+ 1. Learn K-means clustering model
+ ```
+ N_CLUSTERS=<number_of_clusters_used_for_kmeans>
+ TYPE=<one_of_logmel/cpc/hubert/w2v2>
+ CKPT_PATH=<path_of_pretrained_acoustic_model>
+ LAYER=<layer_of_acoustic_model_to_extract_features_from>
+ MANIFEST=<tab_separated_manifest_of_audio_files_for_training_kmeans>
+ KM_MODEL_PATH=<output_path_of_the_kmeans_model>
+
+ PYTHONPATH=. python examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py \
+     --num_clusters $N_CLUSTERS \
+     --feature_type $TYPE \
+     --checkpoint_path $CKPT_PATH \
+     --layer $LAYER \
+     --manifest_path $MANIFEST \
+     --out_kmeans_model_path $KM_MODEL_PATH
+ ```
+ 2. Quantize using the learned clusters
+ ```
+ MANIFEST=<tab_separated_manifest_of_audio_files_to_quantize>
+ OUT_QUANTIZED_FILE=<output_quantized_audio_file_path>
+
+ python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \
+     --feature_type $TYPE \
+     --kmeans_model_path $KM_MODEL_PATH \
+     --acoustic_model_path $CKPT_PATH \
+     --layer $LAYER \
+     --manifest_path $MANIFEST \
+     --out_quantized_file_path $OUT_QUANTIZED_FILE \
+     --extension ".flac"
+ ```
+
+ Note: the manifest file is a file with the paths and lengths of the input audio files. The format of the file is as follows:
+ ```
+ <path_of_root_directory_containing_audio_files>
+ <relative_path_of_audio_file_1>\t<number_of_frames_1>
+ <relative_path_of_audio_file_2>\t<number_of_frames_2>
+ ...
+ ```
+
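
For reference, a small sketch (not part of the released files) that writes a manifest in the format described above; the root directory, output name, and extension are placeholders, and `soundfile` is used to count frames:

```python
import os
import soundfile as sf

audio_dir = "/path/to/audio"              # hypothetical root directory
extension = ".flac"

with open("train.tsv", "w") as fout:
    fout.write(audio_dir + "\n")          # first line: root directory
    for dirpath, _, fnames in os.walk(audio_dir):
        for fname in sorted(fnames):
            if not fname.endswith(extension):
                continue
            path = os.path.join(dirpath, fname)
            frames = sf.info(path).frames                 # number of audio samples
            rel = os.path.relpath(path, audio_dir)
            fout.write(f"{rel}\t{frames}\n")              # <relative_path>\t<number_of_frames>
```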
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py ADDED
@@ -0,0 +1,91 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+
+ from examples.textless_nlp.gslm.speech2unit.pretrained.utils import (
+     get_and_dump_features,
+ )
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser(
+         description="Compute and dump log mel fbank features."
+     )
+     parser.add_argument(
+         "--feature_type",
+         type=str,
+         choices=["logmel", "hubert", "w2v2", "cpc"],
+         default=None,
+         help="Acoustic feature type",
+     )
+     parser.add_argument(
+         "--manifest_path",
+         type=str,
+         default=None,
+         help="Manifest file containing the root dir and file names",
+     )
+     parser.add_argument(
+         "--out_features_path",
+         type=str,
+         default=None,
+         help="Features file path to write to",
+     )
+     parser.add_argument(
+         "--checkpoint_path",
+         type=str,
+         help="Pretrained acoustic model checkpoint",
+     )
+     parser.add_argument(
+         "--layer",
+         type=int,
+         help="The layer of the pretrained model to extract features from",
+         default=-1,
+     )
+     parser.add_argument(
+         "--sample_pct",
+         type=float,
+         help="Percent data to use for K-means training",
+         default=0.1,
+     )
+     return parser
+
+
+ def get_logger():
+     log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+     logging.basicConfig(format=log_format, level=logging.INFO)
+     logger = logging.getLogger(__name__)
+     return logger
+
+
+ if __name__ == "__main__":
+     """
+     Example command:
+     python ~/speechbot/clustering/dump_logmelfank_feats.py \
+         --manifest_path /checkpoint/kushall/data/LJSpeech-1.1/asr_input_wavs_16k/train.tsv \
+         --out_features_path /checkpoint/kushall/experiments/speechbot/logmelfbank/features/ljspeech/train.npy
+     """
+     parser = get_parser()
+     args = parser.parse_args()
+     logger = get_logger()
+     logger.info(args)
+
+     logger.info(f"Extracting {args.feature_type} acoustic features...")
+     get_and_dump_features(
+         feature_type=args.feature_type,
+         checkpoint_path=args.checkpoint_path,
+         layer=args.layer,
+         manifest_path=args.manifest_path,
+         sample_pct=args.sample_pct,
+         flatten=True,
+         out_features_path=args.out_features_path,
+     )
+     logger.info(f"Saved extracted features at {args.out_features_path}")
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py ADDED
@@ -0,0 +1,141 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+ import os
+
+ import numpy as np
+
+ import joblib
+ from examples.textless_nlp.gslm.speech2unit.clustering.utils import (
+     get_audio_files,
+ )
+ from examples.textless_nlp.gslm.speech2unit.pretrained.utils import (
+     get_features,
+ )
+
+
+ def get_logger():
+     log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+     logging.basicConfig(format=log_format, level=logging.INFO)
+     logger = logging.getLogger(__name__)
+     return logger
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser(
+         description="Quantize using K-means clustering over acoustic features."
+     )
+     parser.add_argument(
+         "--feature_type",
+         type=str,
+         choices=["logmel", "hubert", "w2v2", "cpc"],
+         default=None,
+         required=True,
+         help="Acoustic feature type",
+     )
+     parser.add_argument(
+         "--acoustic_model_path",
+         type=str,
+         help="Pretrained acoustic model checkpoint"
+     )
+     parser.add_argument(
+         "--layer",
+         type=int,
+         help="The layer of the pretrained model to extract features from",
+         default=-1,
+     )
+     parser.add_argument(
+         "--kmeans_model_path",
+         type=str,
+         required=True,
+         help="K-means model file path to use for inference",
+     )
+     parser.add_argument(
+         "--features_path",
+         type=str,
+         default=None,
+         help="Features file path. You don't need to enter acoustic model details if you have dumped features",
+     )
+     parser.add_argument(
+         "--manifest_path",
+         type=str,
+         default=None,
+         help="Manifest file containing the root dir and file names",
+     )
+     parser.add_argument(
+         "--out_quantized_file_path",
+         required=True,
+         type=str,
+         help="File path of quantized output.",
+     )
+     parser.add_argument(
+         "--extension", type=str, default=".flac", help="Audio file extension"
+     )
+     parser.add_argument(
+         "--channel_id",
+         choices=['1', '2'],
+         help="The audio channel to extract the units in case of stereo file.",
+         default=None,
+     )
+     parser.add_argument(
+         "--hide-fname", action='store_true',
+         help="Hide file names in the output file."
+     )
+     return parser
+
+
+ def main(args, logger):
+     # Feature extraction
+     if args.features_path is not None:
+         logger.info(f"Loading acoustic features from {args.features_path}...")
+         features_batch = np.load(args.features_path)
+     else:
+         logger.info(f"Extracting {args.feature_type} acoustic features...")
+         features_batch = get_features(
+             feature_type=args.feature_type,
+             checkpoint_path=args.acoustic_model_path,
+             layer=args.layer,
+             manifest_path=args.manifest_path,
+             sample_pct=1.0,
+             flatten=False,
+             channel_id=int(args.channel_id) if args.channel_id else None,
+         )
+         logger.info(
+             f"Features extracted for {len(features_batch)} utterances.\n"
+         )
+         logger.info(
+             f"Dimensionality of representation = {features_batch[0].shape[1]}"
+         )
+
+     # K-means model
+     logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
+     kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
+     kmeans_model.verbose = False
+
+     _, fnames, _ = get_audio_files(args.manifest_path)
+
+     os.makedirs(os.path.dirname(args.out_quantized_file_path), exist_ok=True)
+     print(f"Writing quantized predictions to {args.out_quantized_file_path}")
+     with open(args.out_quantized_file_path, "w") as fout:
+         for i, feats in enumerate(features_batch):
+             pred = kmeans_model.predict(feats)
+             pred_str = " ".join(str(p) for p in pred)
+             base_fname = os.path.basename(fnames[i]).rstrip('.' + args.extension.lstrip('.'))
+             if args.channel_id is not None:
+                 base_fname = base_fname + f'-channel{args.channel_id}'
+             if not args.hide_fname:
+                 fout.write(f"{base_fname}|{pred_str}\n")
+             else:
+                 fout.write(f"{pred_str}\n")
+
+
+ if __name__ == "__main__":
+     parser = get_parser()
+     args = parser.parse_args()
+     logger = get_logger()
+     logger.info(args)
+     main(args, logger)
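
The output written by the script above is one line per utterance, `<file_name>|<space-separated unit ids>` (or only the units when `--hide-fname` is set). A minimal sketch (not part of the released files) of reading the default format back; the file name is hypothetical:

```python
# Parse the output of quantize_with_kmeans.py into {utterance_name: [unit ids]}.
units = {}
with open("quantized_units.txt") as fin:
    for line in fin:
        name, codes = line.strip().split("|")
        units[name] = [int(c) for c in codes.split()]

# Each value is the frame-level cluster-id sequence for one utterance.
print(next(iter(units.items())))
```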
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import List, Tuple
+
+
+ def get_audio_files(manifest_path: str) -> Tuple[str, List[str], List[int]]:
+     fnames, sizes = [], []
+     with open(manifest_path, "r") as f:
+         root_dir = f.readline().strip()
+         for line in f:
+             items = line.strip().split("\t")
+             assert (
+                 len(items) == 2
+             ), f"File must have two columns separated by tab. Got {line}"
+             fnames.append(items[0])
+             sizes.append(int(items[1]))
+     return root_dir, fnames, sizes
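
A short usage sketch for `get_audio_files` (not part of the released files), assuming a manifest named `train.tsv` in the format produced by the wav2vec manifest script (root directory on the first line, then `relative_path\tnum_frames`) and 16 kHz audio:

```python
import os

from examples.textless_nlp.gslm.speech2unit.clustering.utils import get_audio_files

root_dir, fnames, sizes = get_audio_files("train.tsv")  # hypothetical manifest path

# Absolute paths and total duration, assuming a 16 kHz sample rate.
paths = [os.path.join(root_dir, fname) for fname in fnames]
print(len(paths), "files,", sum(sizes) / 16_000 / 3600, "hours")
```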
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py ADDED
@@ -0,0 +1,204 @@
+ import soundfile as sf
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class CpcFeatureReader:
+     """
+     Wrapper class to run inference on CPC model.
+     Helps extract features for a given audio file.
+     """
+
+     def __init__(
+         self,
+         checkpoint_path,
+         layer,
+         use_encoder_layer=False,
+         norm_features=False,
+         sample_rate=16000,
+         max_chunk=64000,
+         use_cuda=True,
+     ):
+         self.model = load_cpc_model(checkpoint_path, layer).eval()
+         self.sample_rate = sample_rate
+         self.max_chunk = max_chunk
+         self.norm_features = norm_features
+         self.use_encoder_layer = use_encoder_layer
+         self.use_cuda = use_cuda
+         if self.use_cuda:
+             self.model.cuda()
+
+     def read_audio(self, path, ref_len=None, channel_id=None):
+         wav, sr = sf.read(path)
+         if channel_id is not None:
+             assert wav.ndim == 2, \
+                 f"Expected stereo input when channel_id is given ({path})"
+             assert channel_id in [1, 2], \
+                 "channel_id is expected to be in [1, 2]"
+             wav = wav[:, channel_id - 1]
+         if wav.ndim == 2:
+             wav = wav.mean(-1)
+         assert wav.ndim == 1, wav.ndim
+         assert sr == self.sample_rate, sr
+         if ref_len is not None and abs(ref_len - len(wav)) > 160:
+             print(f"ref {ref_len} != read {len(wav)} ({path})")
+         return wav
+
+     def get_feats(self, file_path, ref_len=None, channel_id=None):
+         x = self.read_audio(file_path, ref_len, channel_id)
+         # Inspired from CPC_audio feature_loader.py
+         with torch.no_grad():
+             x = torch.from_numpy(x).float()
+             if self.use_cuda:
+                 x = x.cuda()
+             x = x.view(1, 1, -1)
+             size = x.size(2)
+             feat = []
+             start = 0
+             while start < size:
+                 if start + self.max_chunk > size:
+                     break
+                 x_chunk = x[..., start : start + self.max_chunk]
+                 feat_chunk = self.model.extract_features(
+                     source=x_chunk,
+                     get_encoded=self.use_encoder_layer,
+                     norm_output=self.norm_features,
+                 )
+                 feat.append(feat_chunk)
+                 start += self.max_chunk
+
+             if start < size:
+                 x_chunk = x[:, -self.max_chunk :]
+                 feat_chunk = self.model.extract_features(
+                     source=x_chunk,
+                     get_encoded=self.use_encoder_layer,
+                     norm_output=self.norm_features,
+                 )
+                 df = x_chunk.size(2) // feat_chunk.size(1)
+                 delta = (size - start) // df
+                 feat.append(feat_chunk[:, -delta:])
+         return torch.cat(feat, 1).squeeze(0)
+
+
+ def load_cpc_model(checkpoint_path, layer=None):
+     state_dict = torch.load(checkpoint_path)
+     weights = state_dict["weights"]
+     config = state_dict["config"]
+     if layer is not None:
+         config["nLevelsGRU"] = layer
+
+     encoder = CPCEncoder(config["hiddenEncoder"])
+     ar_net = CPCAR(
+         config["hiddenEncoder"], config["hiddenGar"], False, config["nLevelsGRU"]
+     )
+
+     model = CPCModel(encoder, ar_net)
+     model.load_state_dict(weights, strict=False)
+     model.config = config
+
+     return model
+
+
+ class ChannelNorm(nn.Module):
+     def __init__(self, num_features, epsilon=1e-05, affine=True):
+         super(ChannelNorm, self).__init__()
+         if affine:
+             self.weight = nn.parameter.Parameter(torch.Tensor(1, num_features, 1))
+             self.bias = nn.parameter.Parameter(torch.Tensor(1, num_features, 1))
+         else:
+             self.weight = None
+             self.bias = None
+         self.epsilon = epsilon
+         self.p = 0
+         self.affine = affine
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         if self.affine:
+             torch.nn.init.ones_(self.weight)
+             torch.nn.init.zeros_(self.bias)
+
+     def forward(self, x):
+         cum_mean = x.mean(dim=1, keepdim=True)
+         cum_var = x.var(dim=1, keepdim=True)
+         x = (x - cum_mean) * torch.rsqrt(cum_var + self.epsilon)
+         if self.weight is not None:
+             x = x * self.weight + self.bias
+         return x
+
+
+ class CPCEncoder(nn.Module):
+     def __init__(self, hidden_dim=512):
+         super(CPCEncoder, self).__init__()
+         self.conv0 = nn.Conv1d(1, hidden_dim, 10, stride=5, padding=3)
+         self.batchNorm0 = ChannelNorm(hidden_dim)
+         self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, 8, stride=4, padding=2)
+         self.batchNorm1 = ChannelNorm(hidden_dim)
+         self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
+         self.batchNorm2 = ChannelNorm(hidden_dim)
+         self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
+         self.batchNorm3 = ChannelNorm(hidden_dim)
+         self.conv4 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
+         self.batchNorm4 = ChannelNorm(hidden_dim)
+         self.DOWNSAMPLING = 160
+
+     def get_output_dim(self):
+         return self.conv4.out_channels
+
+     def forward(self, x):
+         x = F.relu(self.batchNorm0(self.conv0(x)))
+         x = F.relu(self.batchNorm1(self.conv1(x)))
+         x = F.relu(self.batchNorm2(self.conv2(x)))
+         x = F.relu(self.batchNorm3(self.conv3(x)))
+         x = F.relu(self.batchNorm4(self.conv4(x)))
+         return x
+
+
+ class CPCAR(nn.Module):
+     def __init__(self, dim_encoded, dim_output, keep_hidden, num_layers):
+         super(CPCAR, self).__init__()
+         self.baseNet = nn.LSTM(
+             dim_encoded, dim_output, num_layers=num_layers, batch_first=True
+         )
+         self.hidden = None
+         self.keep_hidden = keep_hidden
+
+     def get_output_dim(self):
+         return self.baseNet.hidden_size
+
+     def forward(self, x):
+         try:
+             self.baseNet.flatten_parameters()
+         except RuntimeError:
+             pass
+         x, h = self.baseNet(x, self.hidden)
+         if self.keep_hidden:
+             if isinstance(h, tuple):
+                 self.hidden = tuple(x.detach() for x in h)
+             else:
+                 self.hidden = h.detach()
+         return x
+
+
+ class CPCModel(nn.Module):
+     def __init__(self, encoder, ar_net):
+         super(CPCModel, self).__init__()
+         self.gEncoder = encoder
+         self.gAR = ar_net
+         self.config = None
+
+     def forward(self, x, label):
+         encoded = self.gEncoder(x).permute(0, 2, 1)
+         cpc_feature = self.gAR(encoded)
+         return cpc_feature, encoded, label
+
+     def extract_features(self, source, get_encoded=False, norm_output=False):
+         cpc_feature, encoded, _ = self.forward(source, None)
+         if get_encoded:
+             cpc_feature = encoded
+         if norm_output:
+             mean = cpc_feature.mean(dim=1, keepdim=True)
+             var = cpc_feature.var(dim=1, keepdim=True)
+             cpc_feature = (cpc_feature - mean) / torch.sqrt(var + 1e-08)
+         return cpc_feature
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py ADDED
@@ -0,0 +1,70 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import fairseq
+ import soundfile as sf
+ import torch.nn.functional as F
+
+
+ class HubertFeatureReader:
+     """
+     Wrapper class to run inference on HuBERT model.
+     Helps extract features for a given audio file.
+     """
+
+     def __init__(self, checkpoint_path, layer, max_chunk=1600000, use_cuda=True):
+         (
+             model,
+             cfg,
+             task,
+         ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+             [checkpoint_path]
+         )
+         self.model = model[0].eval()
+         self.task = task
+         self.layer = layer
+         self.max_chunk = max_chunk
+         self.use_cuda = use_cuda
+         if self.use_cuda:
+             self.model.cuda()
+
+     def read_audio(self, path, ref_len=None, channel_id=None):
+         wav, sr = sf.read(path)
+         if channel_id is not None:
+             assert wav.ndim == 2, \
+                 f"Expected stereo input when channel_id is given ({path})"
+             assert channel_id in [1, 2], \
+                 "channel_id is expected to be in [1, 2]"
+             wav = wav[:, channel_id - 1]
+         if wav.ndim == 2:
+             wav = wav.mean(-1)
+         assert wav.ndim == 1, wav.ndim
+         assert sr == self.task.cfg.sample_rate, sr
+         if ref_len is not None and abs(ref_len - len(wav)) > 160:
+             print(f"ref {ref_len} != read {len(wav)} ({path})")
+         return wav
+
+     def get_feats(self, file_path, ref_len=None, channel_id=None):
+         x = self.read_audio(file_path, ref_len, channel_id)
+         with torch.no_grad():
+             x = torch.from_numpy(x).float()
+             if self.use_cuda:
+                 x = x.cuda()
+             if self.task.cfg.normalize:
+                 x = F.layer_norm(x, x.shape)
+             x = x.view(1, -1)
+
+             feat = []
+             for start in range(0, x.size(1), self.max_chunk):
+                 x_chunk = x[:, start: start + self.max_chunk]
+                 feat_chunk, _ = self.model.extract_features(
+                     source=x_chunk,
+                     padding_mask=None,
+                     mask=False,
+                     output_layer=self.layer,
+                 )
+                 feat.append(feat_chunk)
+         return torch.cat(feat, 1).squeeze(0)
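
A minimal usage sketch (not part of the released files) combining the reader above with a k-means checkpoint, which is essentially what quantize_with_kmeans.py does for a single file. The paths and the layer choice are placeholders (the dGSLM README above, for instance, uses layer 12):

```python
import joblib

from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import (
    HubertFeatureReader,
)

reader = HubertFeatureReader(
    checkpoint_path="/path/to/hubert_base_ls960.pt",  # placeholder checkpoint path
    layer=6,                                          # placeholder layer
)
feats = reader.get_feats("/path/to/audio.wav")        # (num_frames, feat_dim) torch tensor

kmeans = joblib.load("/path/to/km.bin")               # placeholder k-means model
units = kmeans.predict(feats.cpu().numpy())           # frame-level discrete units
print(units[:10])
```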
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import soundfile as sf
+ import torch
+ import torchaudio.compliance.kaldi as kaldi
+
+
+ class LogMelFeatureReader:
+     """
+     Wrapper class to compute log-Mel filterbank features.
+     Helps extract features for a given audio file.
+     """
+
+     def __init__(self, *args, **kwargs):
+         self.num_mel_bins = kwargs.get("num_mel_bins", 80)
+         self.frame_length = kwargs.get("frame_length", 25.0)
+
+     def get_feats(self, file_path, channel_id=None):
+         wav, sr = sf.read(file_path)
+         if channel_id is not None:
+             assert wav.ndim == 2, \
+                 f"Expected stereo input when channel_id is given ({file_path})"
+             wav = wav[:, channel_id - 1]
+         feats = torch.from_numpy(wav).float()
+         feats = kaldi.fbank(
+             feats.unsqueeze(0),
+             num_mel_bins=self.num_mel_bins,
+             frame_length=self.frame_length,
+             sample_frequency=sr,
+         )
+         return feats
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py ADDED
@@ -0,0 +1,127 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import gc
+ import os
+ import random
+ import shutil
+ import numpy as np
+
+ import torch
+ import tqdm
+ from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import (
+     CpcFeatureReader,
+ )
+ from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import (
+     HubertFeatureReader,
+ )
+ from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import (
+     LogMelFeatureReader,
+ )
+ from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import (
+     Wav2VecFeatureReader,
+ )
+
+
+ def get_feature_reader(feature_type):
+     if feature_type == "logmel":
+         return LogMelFeatureReader
+     elif feature_type == "hubert":
+         return HubertFeatureReader
+     elif feature_type == "w2v2":
+         return Wav2VecFeatureReader
+     elif feature_type == "cpc":
+         return CpcFeatureReader
+     else:
+         raise NotImplementedError(f"{feature_type} is not supported.")
+
+
+ def get_feature_iterator(
+     feature_type, checkpoint_path, layer, manifest_path, sample_pct, channel_id=None
+ ):
+     feature_reader_cls = get_feature_reader(feature_type)
+     with open(manifest_path, "r") as fp:
+         lines = fp.read().split("\n")
+         root = lines.pop(0).strip()
+         file_path_list = [
+             os.path.join(root, line.split("\t")[0])
+             for line in lines
+             if len(line) > 0
+         ]
+         if sample_pct < 1.0:
+             file_path_list = random.sample(
+                 file_path_list, int(sample_pct * len(file_path_list))
+             )
+         num_files = len(file_path_list)
+         reader = feature_reader_cls(
+             checkpoint_path=checkpoint_path, layer=layer
+         )
+
+         def iterate():
+             for file_path in file_path_list:
+                 feats = reader.get_feats(file_path, channel_id=channel_id)
+                 yield feats.cpu().numpy()
+
+     return iterate, num_files
+
+
+ def get_features(
+     feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten, channel_id=None
+ ):
+     generator, num_files = get_feature_iterator(
+         feature_type=feature_type,
+         checkpoint_path=checkpoint_path,
+         layer=layer,
+         manifest_path=manifest_path,
+         sample_pct=sample_pct,
+         channel_id=channel_id
+     )
+     iterator = generator()
+
+     features_list = []
+     for features in tqdm.tqdm(iterator, total=num_files):
+         features_list.append(features)
+
+     # Explicit clean up
+     del iterator
+     del generator
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     if flatten:
+         return np.concatenate(features_list)
+
+     return features_list
+
+
+ def get_and_dump_features(
+     feature_type,
+     checkpoint_path,
+     layer,
+     manifest_path,
+     sample_pct,
+     flatten,
+     out_features_path,
+ ):
+     # Feature extraction
+     features_batch = get_features(
+         feature_type=feature_type,
+         checkpoint_path=checkpoint_path,
+         layer=layer,
+         manifest_path=manifest_path,
+         sample_pct=sample_pct,
+         flatten=flatten,
+     )
+
+     # Save features
+     out_dir_path = os.path.dirname(out_features_path)
+     os.makedirs(out_dir_path, exist_ok=True)
+     shutil.copyfile(
+         manifest_path,
+         os.path.join(out_dir_path, os.path.basename(manifest_path)),
+     )
+     np.save(out_features_path, features_batch)
+
+     return features_batch
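
A short sketch (not part of the released files) of calling the helpers above directly, e.g. to dump features for K-means training as dump_feats.py does; all paths are placeholders, and `channel_id` is left at its default (mono audio):

```python
from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_and_dump_features

# Extract features for 10% of the files listed in the manifest and save them
# as a single flattened array.
features = get_and_dump_features(
    feature_type="hubert",
    checkpoint_path="/path/to/hubert_base_ls960.pt",   # placeholder checkpoint
    layer=6,                                           # placeholder layer
    manifest_path="/path/to/train.tsv",                # placeholder manifest
    sample_pct=0.1,
    flatten=True,
    out_features_path="/path/to/train_feats.npy",      # placeholder output path
)
print(features.shape)
```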
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import fairseq
8
+ import soundfile as sf
9
+
10
+
11
+ class Wav2VecFeatureReader:
12
+ """
13
+ Wrapper class to run inference on Wav2Vec 2.0 model.
14
+ Helps extract features for a given audio file.
15
+ """
16
+
17
+ def __init__(self, checkpoint_path, layer, use_cuda=True):
18
+ state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(
19
+ checkpoint_path
20
+ )
21
+
22
+ w2v_args = state["args"]
23
+ self.task = fairseq.tasks.setup_task(w2v_args)
24
+ model = self.task.build_model(w2v_args)
25
+ model.load_state_dict(state["model"], strict=True)
26
+ model.eval()
27
+ self.model = model
28
+ self.layer = layer
29
+ self.use_cuda = use_cuda
30
+ if self.use_cuda:
31
+ self.model.cuda()
32
+
33
+ def read_audio(self, fname, channel_id=None):
34
+ wav, sr = sf.read(fname)
35
+ if channel_id is not None:
36
+ assert wav.ndim == 2, \
37
+ f"Expected stereo input when channel_id is given ({fname})"
38
+ assert channel_id in [1, 2], \
39
+ "channel_id is expected to be in [1, 2]"
40
+ wav = wav[:, channel_id-1]
41
+ if wav.ndim == 2:
42
+ wav = wav.mean(-1)
43
+ assert wav.ndim == 1, wav.ndim
44
+ assert sr == self.task.cfg.sample_rate, sr
45
+ return wav
46
+
47
+ def get_feats(self, file_path, channel_id=None):
48
+ x = self.read_audio(file_path, channel_id)
49
+ with torch.no_grad():
50
+ source = torch.from_numpy(x).view(1, -1).float()
51
+ if self.use_cuda:
52
+ source = source.cuda()
53
+ res = self.model(
54
+ source=source, mask=False, features_only=True, layer=self.layer
55
+ )
56
+ return res["layer_results"][self.layer][0].squeeze(1)
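+
+
+ # Illustrative usage (a sketch, not part of the original file; the paths and the
+ # layer index are placeholders, and the layer should match the one the k-means
+ # model was trained on):
+ #
+ #   reader = Wav2VecFeatureReader("/path/to/wav2vec2.pt", layer=14)
+ #   feats = reader.get_feats("/path/to/utterance.wav")  # torch.Tensor of shape (T, C)
+ #   np_feats = feats.cpu().numpy()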
fairseq/examples/textless_nlp/gslm/tools/README.md ADDED
@@ -0,0 +1,25 @@
+ # GSLM Tools
+
+ ## Resynthesis
+ You can use the command line tool below to resynthesize an input audio file. The tool implements the unsupervised resynthesis method described in the paper; an example interactive session is shown after the command.
+ ```
+ FAIRSEQ_ROOT=<path_to_your_fairseq_repo_root>
+ TYPE=<one_of_logmel/cpc/hubert/w2v2>
+ ACOUSTIC_MODEL_PATH=<path_of_pretrained_acoustic_model>
+ LAYER=<layer_of_acoustic_model_to_extract_features_from>
+ KM_MODEL_PATH=<output_path_of_the_kmeans_model>
+ TTS_MODEL_PATH=<unit2speech_model_file_path>
+ # A text file containing the codes, one per line
+ CODE_DICT_PATH=<unit2speech_code_dict_path>
+ WAVEGLOW_PATH=<path_where_you_have_downloaded_waveglow_checkpoint>
+
+ PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/tools/resynthesize_speech.py \
+ --feature_type $TYPE \
+ --acoustic_model_path $ACOUSTIC_MODEL_PATH \
+ --layer $LAYER \
+ --kmeans_model_path $KM_MODEL_PATH \
+ --tts_model_path $TTS_MODEL_PATH \
+ --code_dict_path $CODE_DICT_PATH \
+ --waveglow_path $WAVEGLOW_PATH \
+ --max_decoder_steps 2000
+ ```
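+
+ Once the models are loaded, the tool runs an interactive loop: it prompts for the path of an input audio file and the path to write the resynthesized audio to, and repeats until interrupted. A minimal illustrative session (the file paths are placeholders):
+ ```
+ Input: Enter the full file path of audio file...
+ /data/audio/sample.wav
+ Output: Enter the full file path of audio file...
+ /data/audio/sample_resynth.wav
+ [...] [INFO]: Resynthesis done!
+ ```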
fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import gc
8
+ import logging
9
+ import os
10
+
11
+ import joblib
12
+ import soundfile as sf
13
+ import torch
14
+ from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_feature_reader
15
+ from examples.textless_nlp.gslm.unit2speech.tts_data import TacotronInputDataset
16
+ from examples.textless_nlp.gslm.unit2speech.utils import (
17
+ load_tacotron,
18
+ load_waveglow,
19
+ synthesize_audio,
20
+ )
21
+
22
+
23
+ def get_logger():
24
+ log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
25
+ logging.basicConfig(format=log_format, level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
27
+ return logger
28
+
29
+
30
+ def get_parser():
31
+ parser = argparse.ArgumentParser(description="GSLM U2S tool")
32
+ parser.add_argument(
33
+ "--feature_type",
34
+ type=str,
35
+ choices=["logmel", "hubert", "w2v2", "cpc"],
36
+ default=None,
37
+ required=True,
38
+ help="Acoustic feature type",
39
+ )
40
+ parser.add_argument(
41
+ "--acoustic_model_path",
42
+ type=str,
43
+ help="Pretrained acoustic model checkpoint",
44
+ )
45
+ parser.add_argument("--layer", type=int, help="Layer of acoustic model")
46
+ parser.add_argument(
47
+ "--kmeans_model_path",
48
+ type=str,
49
+ required=True,
50
+ help="K-means model file path to use for inference",
51
+ )
52
+ parser.add_argument(
53
+ "--tts_model_path",
54
+ type=str,
55
+ help="TTS model file path to use for inference",
56
+ )
57
+ parser.add_argument(
58
+ "--code_dict_path",
59
+ type=str,
60
+ help="Code dict file path to use for inference",
61
+ )
62
+ parser.add_argument(
63
+ "--waveglow_path",
64
+ type=str,
65
+ help="Waveglow (vocoder) model file path to use for inference",
66
+ )
67
+ parser.add_argument("--max_decoder_steps", type=int, default=2000)
68
+ parser.add_argument("--denoiser_strength", type=float, default=0.1)
69
+ return parser
70
+
71
+
72
+ ################################################
73
+ def main(args, logger):
74
+ # Acoustic Model
75
+ logger.info(f"Loading acoustic model from {args.tts_model_path}...")
76
+ feature_reader_cls = get_feature_reader(args.feature_type)
77
+ reader = feature_reader_cls(
78
+ checkpoint_path=args.acoustic_model_path, layer=args.layer
79
+ )
80
+
81
+ # K-means Model
82
+ logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
83
+ kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
84
+ kmeans_model.verbose = False
85
+
86
+ # TTS Model
87
+ logger.info(f"Loading TTS model from {args.tts_model_path}...")
88
+ tacotron_model, sample_rate, hparams = load_tacotron(
89
+ tacotron_model_path=args.tts_model_path,
90
+ max_decoder_steps=args.max_decoder_steps,
91
+ )
92
+
93
+ # Waveglow Model
94
+ logger.info(f"Loading Waveglow model from {args.waveglow_path}...")
95
+ waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path)
96
+
97
+ # Dataset
98
+ if not os.path.exists(hparams.code_dict):
99
+ hparams.code_dict = args.code_dict_path
100
+ tts_dataset = TacotronInputDataset(hparams)
101
+
102
+ iters = 0
103
+ while True:
104
+ in_file_path = input("Input: Enter the full file path of audio file...\n")
105
+ out_file_path = input("Output: Enter the full file path of audio file...\n")
106
+ feats = reader.get_feats(in_file_path).cpu().numpy()
107
+ iters += 1
108
+ if iters == 1000:
109
+ gc.collect()
110
+ torch.cuda.empty_cache()
111
+
112
+ quantized_units = kmeans_model.predict(feats)
113
+ quantized_units_str = " ".join(map(str, quantized_units))
114
+
115
+ tts_input = tts_dataset.get_tensor(quantized_units_str)
116
+ mel, aud, aud_dn, has_eos = synthesize_audio(
117
+ tacotron_model,
118
+ waveglow,
119
+ denoiser,
120
+ tts_input.unsqueeze(0),
121
+ strength=args.denoiser_strength,
122
+ )
123
+ sf.write(f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate)
124
+ logger.info("Resynthesis done!\n")
125
+
126
+
127
+ if __name__ == "__main__":
128
+ parser = get_parser()
129
+ args = parser.parse_args()
130
+ logger = get_logger()
131
+ logger.info(args)
132
+ main(args, logger)
fairseq/examples/textless_nlp/gslm/ulm/README.md ADDED
@@ -0,0 +1,72 @@
1
+ # Unit Language Model (ULM)
2
+
3
+ Here you can find links to the pre-trained ULMs and instructions on training new models using fairseq. At the end of the page, we also share how to run sampling for those models and provide pointers to the transcribed prompts we used.
4
+
5
+ ## Pre-trained models
6
+
7
+ Using the links below, you can download pre-trained models for various unit types and vocabulary sizes:
8
+
9
+ | | 50 | 100 | 200
10
+ |-|-|-|-
11
+ | LogMel Filterbank | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km50/logmel50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km100/logmel100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km200/logmel200_lm.tgz)
12
+ | Modified CPC | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km50/cpc50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km100/cpc100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km200/cpc200_lm.tgz)
13
+ | HuBERT | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km50/hubert50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km100/hubert100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km200/hubert200_lm.tgz)
14
+ | Wav2Vec 2.0 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km50/w2v2_50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km100/w2v2_100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km200/w2v2_200_lm.tgz)
15
+
16
+
17
+ ## Preprocessing data
18
+ Assuming that unit-transcribed train, valid, and test sets are located in `data/train.txt`, `data/valid.txt`, and `data/test.txt`, respectively,
19
+ we run the following command to get a preprocessed version of the dataset in `data-bin`:
20
+
21
+ ```bash
22
+ fairseq-preprocess --only-source \
23
+ --trainpref data/train.txt --validpref data/valid.txt --testpref data/test.txt \
24
+ --destdir data-bin/ --workers 40
25
+ ```
26
+ As a result, the `data-bin` directory should appear.
27
+
28
+ ## Fitting a Unit Language Model (ULM)
29
+ As the ULM, we train a standard fairseq Transformer LM. Assuming 8 GPUs are used for training, a good starting point for ULM training would be:
30
+ ```bash
31
+ fairseq-train data-bin/ \
32
+ --task=language_modeling \
33
+ --arch=transformer_lm_big \
34
+ --share-decoder-input-output-embed \
35
+ --dropout=0.1 \
36
+ --attention-dropout=0.1 \
37
+ --optimizer=adam \
38
+ --adam-betas='(0.9, 0.98)' \
39
+ --clip-norm=1.0 \
40
+ --lr=0.0005 \
41
+ --lr-scheduler=inverse_sqrt \
42
+ --warmup-updates=4000 \
43
+ --warmup-init-lr=1e-07 \
44
+ --tokens-per-sample=3072 \
45
+ --update-freq=16 \
46
+ --max-tokens=4096 \
47
+ --num-workers=4 \
48
+ --skip-invalid-size-inputs-valid-test \
49
+ --max-update=500000 \
50
+ --log-interval=10 \
51
+ --seed=100501 \
52
+ --fp16 \
53
+ --sample-break-mode=eos
54
+ ```
55
+ This command will train a Transformer-large model (12 layers). You can train other standard LM models provided by fairseq, e.g. specify `--arch=transformer_lm` to train a smaller (6-layer) Transformer model. When training with a different number of GPUs, it might be a good idea to adjust the `update-freq` parameter, as illustrated below. To save GPU memory at the expense of additional computation, it can be useful to enable activation checkpointing with `--checkpoint-activations`.
56
+
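+ For example (a sketch, not part of the original recipe): to keep the effective batch size of the 8-GPU setup when training on 4 GPUs, one can double the gradient accumulation, since 8 GPUs with `--update-freq=16` corresponds to 4 GPUs with `--update-freq=32`:
+ ```bash
+ # same command as above, run on 4 GPUs, with only update-freq changed
+ fairseq-train data-bin/ \
+   ... \
+   --update-freq=32
+ ```
+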
57
+ ## Sampling from a ULM
58
+ Once a ULM has been trained, we can use it to generate new utterances. Suppose the prompts are given in a file named `prompts.txt` (an example of the expected format is given at the end of this page). Then we can sample continuations by running the following command:
59
+
60
+ ```bash
61
+ python sample.py data-bin/ \
62
+ --path=checkpoints/checkpoint_best.pt --task=language_modeling --sampling --temperature=0.7 \
63
+ --seed=1 --prompts=prompts.txt --output=samples.txt --max-len-a=0 --max-len-b=500 \
64
+ --prefix-size=-1 --batch-size=16 --fp16 --samples-per-prompt=10
65
+ ```
66
+ Here, `--prefix-size` controls the number of tokens used to prime the ULM. When set to a positive value, the sampling script takes the first `prefix-size` tokens of each prompt to prime the ULM; with `0` it runs unconditional sampling, and with `-1` the entire prompt is used.
+ `--samples-per-prompt` specifies how many utterances are generated for every prompt, which is useful when sampling multiple continuations per prompt. In this command, `--max-len-a` and `--max-len-b` control the number of generated tokens.
68
+
69
+ When using a pretrained model from above, `data-bin` should point to the unpacked directory (with `dict.txt` file).
70
+
71
+ At evaluation time, to generate prompts, we used utterances from LibriSpeech dev-clean and test-clean that are longer than 6s, taking the first 3s of each utterance as the prompt. Unit transcripts of those prompts can be downloaded here: [[dev]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/dev_prompts.tgz) [[test]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/test_prompts.tgz)
72
+
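+ For reference, `sample.py` expects each line of the prompts file to have the form `<sequence_id>|<space-separated units>`, and it writes each continuation as `<sequence_id>__<sample_index>|<units>` to the output file. A purely illustrative example of a prompts file (the unit values are made up):
+ ```
+ 1272-128104-0000|41 41 12 98 98 7 33 33 56
+ 1272-128104-0001|5 5 74 74 21 63 63 63 11 8
+ ```
+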
fairseq/examples/textless_nlp/gslm/ulm/sample.py ADDED
@@ -0,0 +1,174 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Sample from a trained LM; hacked fairseq-interactive
8
+ """
9
+ from collections import namedtuple
10
+ import os
11
+ import ast
12
+ import numpy as np
13
+
14
+ from fairseq import checkpoint_utils, options, tasks, utils
15
+
16
+ import tqdm
17
+
18
+ Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
19
+ Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
20
+
21
+
22
+ def make_batches(lines, args, task, max_positions):
23
+ tokens = [
24
+ task.source_dictionary.encode_line(
25
+ src_str, add_if_not_exist=False
26
+ ).long()
27
+ for src_str in lines
28
+ ]
29
+ lengths = [t.numel() for t in tokens]
30
+ itr = task.get_batch_iterator(
31
+ dataset=task.build_dataset_for_inference(tokens, lengths),
32
+ max_tokens=args.dataset.max_tokens,
33
+ max_sentences=args.dataset.batch_size,
34
+ max_positions=max_positions,
35
+ ignore_invalid_inputs=args.dataset.skip_invalid_size_inputs_valid_test
36
+ ).next_epoch_itr(shuffle=False)
37
+ for batch in itr:
38
+ yield Batch(
39
+ ids=batch['id'],
40
+ src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'],
41
+ )
42
+
43
+
44
+ def main(args):
45
+ arg_prompts = args.prompts
46
+ arg_output = args.output
47
+ arg_debug = args.debug
48
+ arg_sample_size = args.samples_per_prompt
49
+
50
+ try:
51
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
52
+ args = convert_namespace_to_omegaconf(args)
53
+ except:
54
+ pass
55
+
56
+ # if args.max_tokens is None and args.max_sentences is None:
57
+ if args.common.seed is not None:
58
+ np.random.seed(args.common.seed)
59
+ utils.set_torch_seed(args.common.seed)
60
+
61
+ if args.generation.sampling:
62
+ args.generation.nbest = args.generation.beam = arg_sample_size
63
+
64
+ task = tasks.setup_task(args.task)
65
+
66
+ overrides = ast.literal_eval(args.common_eval.model_overrides)
67
+
68
+ models, _model_args = checkpoint_utils.load_model_ensemble(
69
+ args.common_eval.path.split(os.pathsep),
70
+ arg_overrides=overrides,
71
+ task=task,
72
+ suffix=getattr(args, "checkpoint_suffix", ""),
73
+ )
74
+
75
+ # Set dictionaries
76
+ src_dict = task.source_dictionary
77
+ tgt_dict = task.target_dictionary
78
+
79
+ # Optimize ensemble for generation
80
+ for model in models:
81
+ model.prepare_for_inference_(args)
82
+ model.cuda()
83
+
84
+ # Load alignment dictionary for unknown word replacement
85
+ # (None if no unknown word replacement, empty if no path to align dictionary)
86
+ align_dict = utils.load_align_dict(args.generation.replace_unk)
87
+
88
+ max_positions = utils.resolve_max_positions(
89
+ task.max_positions(),
90
+ *[model.max_positions() for model in models]
91
+ )
92
+
93
+ output_file = open(arg_output, 'w')
94
+
95
+ with open(arg_prompts, 'r') as fin:
96
+ lines = fin.readlines()
97
+
98
+ split = [x.split('|', 1) for x in lines]
99
+ seq_id = [x[0] for x in split]
100
+ prompts = [x[1] for x in split]
101
+
102
+ if args.generation.prefix_size >= 0:
103
+ prompts = [' '.join(l.split()[:args.generation.prefix_size])
104
+ for l in prompts]
105
+
106
+ if arg_debug:
107
+ prompts = prompts[:10]
108
+
109
+ generator = task.build_generator(models, args.generation)
110
+
111
+ start_id = 0
112
+ pbar = tqdm.tqdm(total=len(prompts))
113
+ for batch in make_batches(prompts, args, task, max_positions):
114
+ src_tokens = batch.src_tokens
115
+ src_lengths = batch.src_lengths
116
+ src_tokens = src_tokens.cuda()
117
+ src_lengths = src_lengths.cuda()
118
+
119
+ sample = {
120
+ 'net_input': {
121
+ 'src_tokens': src_tokens,
122
+ 'src_lengths': src_lengths,
123
+ },
124
+ }
125
+
126
+ results = []
127
+ translations = task.inference_step(generator, models, sample)
128
+ for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
129
+ src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
130
+ results.append((i + start_id, src_tokens_i, hypos))
131
+
132
+ # sort output to match input order
133
+ for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
134
+ if src_dict is not None:
135
+ src_str = src_dict.string(
136
+ src_tokens, args.common_eval.post_process)
137
+
138
+ # Process top predictions
139
+ for hypo_id, hypo in enumerate(hypos):
140
+ _hypo_tokens, hypo_str, _alignment = utils.post_process_prediction(
141
+ hypo_tokens=hypo['tokens'].int().cpu(),
142
+ src_str=src_str,
143
+ alignment=hypo['alignment'],
144
+ align_dict=align_dict,
145
+ tgt_dict=tgt_dict,
146
+ remove_bpe=args.common_eval.post_process,
147
+ )
148
+
149
+ detok_hypo_str = hypo_str
150
+ utterance = detok_hypo_str
151
+ print(f'{seq_id[id]}__{hypo_id}|{utterance}', file=output_file)
152
+ pbar.update(1)
153
+ start_id += len(results)
154
+
155
+ # output_file.close()
156
+
157
+
158
+ def cli_main():
159
+ parser = options.get_interactive_generation_parser()
160
+ parser.add_argument('--prompts', type=str, default=None, required=True)
161
+ parser.add_argument('--output', type=str, default=None, required=True)
162
+ parser.add_argument('--debug', action='store_true')
163
+ parser.add_argument('--samples-per-prompt', type=int, default=1)
164
+
165
+ args = options.parse_args_and_arch(parser)
166
+
167
+ np.random.seed(args.seed)
168
+ utils.set_torch_seed(args.seed)
169
+
170
+ main(args)
171
+
172
+
173
+ if __name__ == '__main__':
174
+ cli_main()
fairseq/examples/textless_nlp/gslm/unit2speech/README.md ADDED
@@ -0,0 +1,40 @@
1
+ # Unit to Speech Model (unit2speech)
2
+
3
+ The unit2speech model is a modified Tacotron2 model that learns to synthesize speech from discrete speech units. All models are trained on quantized [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
4
+
5
+ Upstream Units | Download Links | model md5
6
+ |-|-|-
7
+ Log Mel Filterbank + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km50/code_dict) | 932b3b8527c0125f5f964b57762eba49
8
+ Log Mel Filterbank + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km100/code_dict) | cde0b0d278a39011d0acbd5df27abdf4
9
+ Log Mel Filterbank + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km200/code_dict) | dba0f1d4de64bc7976718834010b23e7
10
+ Modified CPC + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km50/code_dict) | a585e8dd8890ea56164f17635dd8e613
11
+ Modified CPC + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km100/code_dict) | 5c0ee2869b4f483d17f37f1a41a548e0
12
+ Modified CPC + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km200/code_dict) | 2f0c9951cf37020d9464514bff48bc5d
13
+ HuBERT Base + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km50/code_dict) | 85ffce8baec5aa90035ab696fe676fce
14
+ HuBERT Base + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km100/code_dict) | df4a9c6ffd1bb00c91405432c234aba3
15
+ HuBERT Base + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/code_dict) | ac72f2c0c563589819bec116c7f8d274
16
+ wav2vec 2.0 Large + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km50/code_dict) | e3503d0ad822b2c24b89f68b857fedff
17
+ wav2vec 2.0 Large + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km100/code_dict) | eb3666e456ae4c96bf2a1eec825c13ed
18
+ wav2vec 2.0 Large + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km200/code_dict) | 777d343e963c4d64f04d78eef032f4e8
19
+
20
+ ## Run inference using a unit2speech model
21
+ * Install librosa, unidecode and inflect using `pip install librosa unidecode inflect`
22
+ * Download [Waveglow checkpoint](https://dl.fbaipublicfiles.com/textless_nlp/gslm/waveglow_256channels_new.pt). This is the vocoder.
23
+
24
+ Sample command to run inference using trained unit2speech models. Please note that the quantized audio to be synthesized should use the same units as the unit2speech model was trained with (see the note on the expected file format below the command).
25
+ ```
26
+ FAIRSEQ_ROOT=<path_to_your_fairseq_repo_root>
27
+ TTS_MODEL_PATH=<unit2speech_model_file_path>
28
+ QUANTIZED_UNIT_PATH=<quantized_audio_file_path>
29
+ OUT_DIR=<dir_to_dump_synthesized_audio_files>
30
+ WAVEGLOW_PATH=<path_where_you_have_downloaded_waveglow_checkpoint>
31
+ CODE_DICT_PATH=<unit2speech_code_dict_path>
32
+
33
+ PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py \
34
+ --tts_model_path $TTS_MODEL_PATH \
35
+ --quantized_unit_path $QUANTIZED_UNIT_PATH \
36
+ --out_audio_dir $OUT_DIR \
37
+ --waveglow_path $WAVEGLOW_PATH \
38
+ --code_dict_path $CODE_DICT_PATH \
39
+ --max_decoder_steps 2000
40
+ ```
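+
+ The file passed via `--quantized_unit_path` is the output of the quantization step in `examples/textless_nlp/gslm/speech2unit`; it is expected to contain one utterance per line in the form `<file_name>|<space-separated units>`, which this script splits into an output file name and a unit sequence. A purely illustrative example (the unit values are made up):
+ ```
+ utt_0001|23 23 51 7 7 7 64 12
+ utt_0002|88 4 4 19 19 73
+ ```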
fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py ADDED
@@ -0,0 +1,56 @@
1
+ import os
2
+ import shlex
3
+ import subprocess
4
+ import progressbar
5
+ from time import time
6
+ from pathlib import Path
7
+
8
+ def find_all_files(path_dir, extension):
9
+ out = []
10
+ for root, dirs, filenames in os.walk(path_dir):
11
+ for f in filenames:
12
+ if f.endswith(extension):
13
+ out.append(((str(Path(f).stem)), os.path.join(root, f)))
14
+ return out
15
+
16
+ def convert16k(inputfile, outputfile16k):
17
+ command = ('sox -c 1 -b 16 {} -t wav {} rate 16k'.format(inputfile, outputfile16k))
18
+ subprocess.call(shlex.split(command))
19
+
20
+ if __name__ == "__main__":
21
+ import argparse
22
+
23
+ parser = argparse.ArgumentParser(description='Convert to wav 16k audio using sox.')
24
+ parser.add_argument('input_dir', type=str,
25
+ help='Path to the input dir.')
26
+ parser.add_argument('output_dir', type=str,
27
+ help='Path to the output dir.')
28
+ parser.add_argument('--extension', type=str, default='wav',
29
+ help='Audio file extension in the input. Default: wav.')
30
+ args = parser.parse_args()
31
+
32
+ # Find all sequences
33
+ print(f"Finding all audio files with extension '{args.extension}' from {args.input_dir}...")
34
+ audio_files = find_all_files(args.input_dir, args.extension)
35
+ print(f"Done! Found {len(audio_files)} files.")
36
+
37
+ # Convert to relative path
38
+ audio_files = [os.path.relpath(file[-1], start=args.input_dir) for file in audio_files]
39
+
40
+ # Create all the directories needed
41
+ rel_dirs_set = set([os.path.dirname(file) for file in audio_files])
42
+ for rel_dir in rel_dirs_set:
43
+ Path(os.path.join(args.output_dir, rel_dir)).mkdir(parents=True, exist_ok=True)
44
+
45
+ # Converting wavs files
46
+ print("Converting the audio to wav files...")
47
+ bar = progressbar.ProgressBar(maxval=len(audio_files))
48
+ bar.start()
49
+ start_time = time()
50
+ for index, file in enumerate(audio_files):
51
+ bar.update(index)
52
+ input_file = os.path.join(args.input_dir, file)
53
+ output_file = os.path.join(args.output_dir, os.path.splitext(file)[0]+".wav")
54
+ convert16k(input_file, output_file)
55
+ bar.finish()
56
+ print(f"...done {len(audio_files)} files in {time()-start_time} seconds.")
fairseq/examples/textless_nlp/gslm/unit2speech/glow.py ADDED
@@ -0,0 +1,312 @@
1
+ # *****************************************************************************
2
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of the NVIDIA CORPORATION nor the
12
+ # names of its contributors may be used to endorse or promote products
13
+ # derived from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
+ # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ #
26
+ # *****************************************************************************
27
+ import copy
28
+ import torch
29
+ from torch.autograd import Variable
30
+ import torch.nn.functional as F
31
+
32
+
33
+ @torch.jit.script
34
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
35
+ n_channels_int = n_channels[0]
36
+ in_act = input_a+input_b
37
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
38
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
39
+ acts = t_act * s_act
40
+ return acts
41
+
42
+
43
+ class WaveGlowLoss(torch.nn.Module):
44
+ def __init__(self, sigma=1.0):
45
+ super(WaveGlowLoss, self).__init__()
46
+ self.sigma = sigma
47
+
48
+ def forward(self, model_output):
49
+ z, log_s_list, log_det_W_list = model_output
50
+ for i, log_s in enumerate(log_s_list):
51
+ if i == 0:
52
+ log_s_total = torch.sum(log_s)
53
+ log_det_W_total = log_det_W_list[i]
54
+ else:
55
+ log_s_total = log_s_total + torch.sum(log_s)
56
+ log_det_W_total += log_det_W_list[i]
57
+
58
+ loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
59
+ return loss/(z.size(0)*z.size(1)*z.size(2))
60
+
61
+
62
+ class Invertible1x1Conv(torch.nn.Module):
63
+ """
64
+ The layer outputs both the convolution, and the log determinant
65
+ of its weight matrix. If reverse=True it does convolution with
66
+ inverse
67
+ """
68
+ def __init__(self, c):
69
+ super(Invertible1x1Conv, self).__init__()
70
+ self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
71
+ bias=False)
72
+
73
+ # Sample a random orthonormal matrix to initialize weights
74
+ _qr = torch.linalg.qr if torch.__version__ >= "1.8" else torch.qr
75
+ W = _qr(torch.FloatTensor(c, c).normal_())[0]
76
+
77
+ # Ensure determinant is 1.0 not -1.0
78
+ if torch.det(W) < 0:
79
+ W[:,0] = -1*W[:,0]
80
+ W = W.view(c, c, 1)
81
+ self.conv.weight.data = W
82
+
83
+ def forward(self, z, reverse=False):
84
+ # shape
85
+ batch_size, group_size, n_of_groups = z.size()
86
+
87
+ W = self.conv.weight.squeeze()
88
+
89
+ if reverse:
90
+ if not hasattr(self, 'W_inverse'):
91
+ # Reverse computation
92
+ W_inverse = W.float().inverse()
93
+ W_inverse = Variable(W_inverse[..., None])
94
+ if z.type() == 'torch.cuda.HalfTensor':
95
+ W_inverse = W_inverse.half()
96
+ self.W_inverse = W_inverse
97
+ z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
98
+ return z
99
+ else:
100
+ # Forward computation
101
+ log_det_W = batch_size * n_of_groups * torch.logdet(W)
102
+ z = self.conv(z)
103
+ return z, log_det_W
104
+
105
+
106
+ class WN(torch.nn.Module):
107
+ """
108
+ This is the WaveNet like layer for the affine coupling. The primary difference
109
+ from WaveNet is the convolutions need not be causal. There is also no dilation
110
+ size reset. The dilation only doubles on each layer
111
+ """
112
+ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
113
+ kernel_size):
114
+ super(WN, self).__init__()
115
+ assert(kernel_size % 2 == 1)
116
+ assert(n_channels % 2 == 0)
117
+ self.n_layers = n_layers
118
+ self.n_channels = n_channels
119
+ self.in_layers = torch.nn.ModuleList()
120
+ self.res_skip_layers = torch.nn.ModuleList()
121
+
122
+ start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
123
+ start = torch.nn.utils.weight_norm(start, name='weight')
124
+ self.start = start
125
+
126
+ # Initializing last layer to 0 makes the affine coupling layers
127
+ # do nothing at first. This helps with training stability
128
+ end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
129
+ end.weight.data.zero_()
130
+ end.bias.data.zero_()
131
+ self.end = end
132
+
133
+ cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
134
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
135
+
136
+ for i in range(n_layers):
137
+ dilation = 2 ** i
138
+ padding = int((kernel_size*dilation - dilation)/2)
139
+ in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
140
+ dilation=dilation, padding=padding)
141
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
142
+ self.in_layers.append(in_layer)
143
+
144
+
145
+ # last one is not necessary
146
+ if i < n_layers - 1:
147
+ res_skip_channels = 2*n_channels
148
+ else:
149
+ res_skip_channels = n_channels
150
+ res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
151
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
152
+ self.res_skip_layers.append(res_skip_layer)
153
+
154
+ def forward(self, forward_input):
155
+ audio, spect = forward_input
156
+ audio = self.start(audio)
157
+ output = torch.zeros_like(audio)
158
+ n_channels_tensor = torch.IntTensor([self.n_channels])
159
+
160
+ spect = self.cond_layer(spect)
161
+
162
+ for i in range(self.n_layers):
163
+ spect_offset = i*2*self.n_channels
164
+ acts = fused_add_tanh_sigmoid_multiply(
165
+ self.in_layers[i](audio),
166
+ spect[:,spect_offset:spect_offset+2*self.n_channels,:],
167
+ n_channels_tensor)
168
+
169
+ res_skip_acts = self.res_skip_layers[i](acts)
170
+ if i < self.n_layers - 1:
171
+ audio = audio + res_skip_acts[:,:self.n_channels,:]
172
+ output = output + res_skip_acts[:,self.n_channels:,:]
173
+ else:
174
+ output = output + res_skip_acts
175
+
176
+ return self.end(output)
177
+
178
+
179
+ class WaveGlow(torch.nn.Module):
180
+ def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
181
+ n_early_size, WN_config):
182
+ super(WaveGlow, self).__init__()
183
+
184
+ self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
185
+ n_mel_channels,
186
+ 1024, stride=256)
187
+ assert(n_group % 2 == 0)
188
+ self.n_flows = n_flows
189
+ self.n_group = n_group
190
+ self.n_early_every = n_early_every
191
+ self.n_early_size = n_early_size
192
+ self.WN = torch.nn.ModuleList()
193
+ self.convinv = torch.nn.ModuleList()
194
+
195
+ n_half = int(n_group/2)
196
+
197
+ # Set up layers with the right sizes based on how many dimensions
198
+ # have been output already
199
+ n_remaining_channels = n_group
200
+ for k in range(n_flows):
201
+ if k % self.n_early_every == 0 and k > 0:
202
+ n_half = n_half - int(self.n_early_size/2)
203
+ n_remaining_channels = n_remaining_channels - self.n_early_size
204
+ self.convinv.append(Invertible1x1Conv(n_remaining_channels))
205
+ self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
206
+ self.n_remaining_channels = n_remaining_channels # Useful during inference
207
+
208
+ def forward(self, forward_input):
209
+ """
210
+ forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
211
+ forward_input[1] = audio: batch x time
212
+ """
213
+ spect, audio = forward_input
214
+
215
+ # Upsample spectrogram to size of audio
216
+ spect = self.upsample(spect)
217
+ assert(spect.size(2) >= audio.size(1))
218
+ if spect.size(2) > audio.size(1):
219
+ spect = spect[:, :, :audio.size(1)]
220
+
221
+ spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
222
+ spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
223
+
224
+ audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
225
+ output_audio = []
226
+ log_s_list = []
227
+ log_det_W_list = []
228
+
229
+ for k in range(self.n_flows):
230
+ if k % self.n_early_every == 0 and k > 0:
231
+ output_audio.append(audio[:,:self.n_early_size,:])
232
+ audio = audio[:,self.n_early_size:,:]
233
+
234
+ audio, log_det_W = self.convinv[k](audio)
235
+ log_det_W_list.append(log_det_W)
236
+
237
+ n_half = int(audio.size(1)/2)
238
+ audio_0 = audio[:,:n_half,:]
239
+ audio_1 = audio[:,n_half:,:]
240
+
241
+ output = self.WN[k]((audio_0, spect))
242
+ log_s = output[:, n_half:, :]
243
+ b = output[:, :n_half, :]
244
+ audio_1 = torch.exp(log_s)*audio_1 + b
245
+ log_s_list.append(log_s)
246
+
247
+ audio = torch.cat([audio_0, audio_1],1)
248
+
249
+ output_audio.append(audio)
250
+ return torch.cat(output_audio,1), log_s_list, log_det_W_list
251
+
252
+ def infer(self, spect, sigma=1.0):
253
+ spect = self.upsample(spect)
254
+ # trim conv artifacts. maybe pad spec to kernel multiple
255
+ time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
256
+ spect = spect[:, :, :-time_cutoff]
257
+
258
+ spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
259
+ spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
260
+
261
+ if spect.type() == 'torch.cuda.HalfTensor':
262
+ audio = torch.cuda.HalfTensor(spect.size(0),
263
+ self.n_remaining_channels,
264
+ spect.size(2)).normal_()
265
+ else:
266
+ audio = torch.cuda.FloatTensor(spect.size(0),
267
+ self.n_remaining_channels,
268
+ spect.size(2)).normal_()
269
+
270
+ audio = torch.autograd.Variable(sigma*audio)
271
+
272
+ for k in reversed(range(self.n_flows)):
273
+ n_half = int(audio.size(1)/2)
274
+ audio_0 = audio[:,:n_half,:]
275
+ audio_1 = audio[:,n_half:,:]
276
+
277
+ output = self.WN[k]((audio_0, spect))
278
+
279
+ s = output[:, n_half:, :]
280
+ b = output[:, :n_half, :]
281
+ audio_1 = (audio_1 - b)/torch.exp(s)
282
+ audio = torch.cat([audio_0, audio_1],1)
283
+
284
+ audio = self.convinv[k](audio, reverse=True)
285
+
286
+ if k % self.n_early_every == 0 and k > 0:
287
+ if spect.type() == 'torch.cuda.HalfTensor':
288
+ z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
289
+ else:
290
+ z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
291
+ audio = torch.cat((sigma*z, audio),1)
292
+
293
+ audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
294
+ return audio
295
+
296
+ @staticmethod
297
+ def remove_weightnorm(model):
298
+ waveglow = model
299
+ for WN in waveglow.WN:
300
+ WN.start = torch.nn.utils.remove_weight_norm(WN.start)
301
+ WN.in_layers = remove(WN.in_layers)
302
+ WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
303
+ WN.res_skip_layers = remove(WN.res_skip_layers)
304
+ return waveglow
305
+
306
+
307
+ def remove(conv_list):
308
+ new_conv_list = torch.nn.ModuleList()
309
+ for old_conv in conv_list:
310
+ old_conv = torch.nn.utils.remove_weight_norm(old_conv)
311
+ new_conv_list.append(old_conv)
312
+ return new_conv_list
fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py ADDED
@@ -0,0 +1,27 @@
1
+ import os
2
+ import time
3
+ import torch
4
+ import sys
5
+ import subprocess
6
+
7
+ argslist = list(sys.argv)[1:]
8
+ log_dir = argslist[-1]
9
+ num_gpus = torch.cuda.device_count()
10
+ argslist.append('--n_gpus={}'.format(num_gpus))
11
+ workers = []
12
+ job_id = time.strftime("%Y_%m_%d-%H%M%S")
13
+ argslist.append("--group_name=group_{}".format(job_id))
14
+
15
+ print("GPU log directory is {}".format(log_dir))
16
+ os.makedirs(log_dir, exist_ok=True)
17
+ for i in range(num_gpus):
18
+ argslist.append('--rank={}'.format(i))
19
+ stdout = None if i == 0 else open("{}/{}_GPU_{}.log".format(log_dir, job_id, i),
20
+ "w")
21
+ print(argslist)
22
+ p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout)
23
+ workers.append(p)
24
+ argslist = argslist[:-1]
25
+
26
+ for p in workers:
27
+ p.wait()
fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import logging
8
+ import os
9
+
10
+ import soundfile as sf
11
+ from examples.textless_nlp.gslm.unit2speech.tts_data import (
12
+ TacotronInputDataset,
13
+ )
14
+ from examples.textless_nlp.gslm.unit2speech.utils import (
15
+ load_quantized_audio_from_file,
16
+ load_tacotron,
17
+ load_waveglow,
18
+ synthesize_audio,
19
+ )
20
+
21
+
22
+ def get_logger():
23
+ log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
24
+ logging.basicConfig(format=log_format, level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+ return logger
27
+
28
+
29
+ def get_parser():
30
+ parser = argparse.ArgumentParser(
31
+ description="GSLM unit2speech audio synthesizer."
32
+ )
33
+ parser.add_argument(
34
+ "--quantized_unit_path",
35
+ type=str,
36
+ help="K-means model file path to use for inference",
37
+ )
38
+ parser.add_argument(
39
+ "--tts_model_path",
40
+ type=str,
41
+ help="TTS model file path to use for inference",
42
+ )
43
+ parser.add_argument(
44
+ "--waveglow_path",
45
+ type=str,
46
+ help="Path to the waveglow checkpoint (vocoder).",
47
+ )
48
+ parser.add_argument(
49
+ "--code_dict_path",
50
+ type=str,
51
+ help="Code dict file path to use for inference",
52
+ )
53
+ parser.add_argument("--max_decoder_steps", type=int, default=2000)
54
+ parser.add_argument("--denoiser_strength", type=float, default=0.1)
55
+ parser.add_argument(
56
+ "--out_audio_dir",
57
+ type=str,
58
+ help="Output directory to dump audio files",
59
+ )
60
+
61
+ return parser
62
+
63
+
64
+ def main(args, logger):
65
+ # Load quantized audio
66
+ logger.info(f"Loading quantized audio from {args.quantized_unit_path}...")
67
+ names_batch, quantized_units_batch = load_quantized_audio_from_file(
68
+ file_path=args.quantized_unit_path
69
+ )
70
+
71
+ logger.info(f"Loading TTS model from {args.tts_model_path}...")
72
+ tacotron_model, sample_rate, hparams = load_tacotron(
73
+ tacotron_model_path=args.tts_model_path,
74
+ max_decoder_steps=args.max_decoder_steps,
75
+ )
76
+
77
+ logger.info(f"Loading Waveglow model from {args.waveglow_path}...")
78
+ waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path)
79
+
80
+ if not os.path.exists(hparams.code_dict):
81
+ hparams.code_dict = args.code_dict_path
82
+ tts_dataset = TacotronInputDataset(hparams)
83
+
84
+ for name, quantized_units in zip(names_batch, quantized_units_batch):
85
+ quantized_units_str = " ".join(map(str, quantized_units))
86
+ tts_input = tts_dataset.get_tensor(quantized_units_str)
87
+ mel, aud, aud_dn, has_eos = synthesize_audio(
88
+ tacotron_model,
89
+ waveglow,
90
+ denoiser,
91
+ tts_input.unsqueeze(0),
92
+ strength=args.denoiser_strength,
93
+ )
94
+ out_file_path = os.path.join(args.out_audio_dir, f"{name}.wav")
95
+ sf.write(
96
+ f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate
97
+ )
98
+
99
+
100
+ if __name__ == "__main__":
101
+ parser = get_parser()
102
+ args = parser.parse_args()
103
+ logger = get_logger()
104
+ logger.info(args)
105
+ main(args, logger)
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py ADDED
File without changes
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import numpy as np
3
+ from scipy.signal import get_window
4
+ import librosa.util as librosa_util
5
+
6
+
7
+ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
8
+ n_fft=800, dtype=np.float32, norm=None):
9
+ """
10
+ # from librosa 0.6
11
+ Compute the sum-square envelope of a window function at a given hop length.
12
+
13
+ This is used to estimate modulation effects induced by windowing
14
+ observations in short-time fourier transforms.
15
+
16
+ Parameters
17
+ ----------
18
+ window : string, tuple, number, callable, or list-like
19
+ Window specification, as in `get_window`
20
+
21
+ n_frames : int > 0
22
+ The number of analysis frames
23
+
24
+ hop_length : int > 0
25
+ The number of samples to advance between frames
26
+
27
+ win_length : [optional]
28
+ The length of the window function. By default, this matches `n_fft`.
29
+
30
+ n_fft : int > 0
31
+ The length of each analysis frame.
32
+
33
+ dtype : np.dtype
34
+ The data type of the output
35
+
36
+ Returns
37
+ -------
38
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
39
+ The sum-squared envelope of the window function
40
+ """
41
+ if win_length is None:
42
+ win_length = n_fft
43
+
44
+ n = n_fft + hop_length * (n_frames - 1)
45
+ x = np.zeros(n, dtype=dtype)
46
+
47
+ # Compute the squared window at the desired length
48
+ win_sq = get_window(window, win_length, fftbins=True)
49
+ win_sq = librosa_util.normalize(win_sq, norm=norm)**2
50
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
51
+
52
+ # Fill the envelope
53
+ for i in range(n_frames):
54
+ sample = i * hop_length
55
+ x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
56
+ return x
57
+
58
+
59
+ def griffin_lim(magnitudes, stft_fn, n_iters=30):
60
+ """
61
+ PARAMS
62
+ ------
63
+ magnitudes: spectrogram magnitudes
64
+ stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
65
+ """
66
+
67
+ angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
68
+ angles = angles.astype(np.float32)
69
+ angles = torch.autograd.Variable(torch.from_numpy(angles))
70
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
71
+
72
+ for i in range(n_iters):
73
+ _, angles = stft_fn.transform(signal)
74
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
75
+ return signal
76
+
77
+
78
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
79
+ """
80
+ PARAMS
81
+ ------
82
+ C: compression factor
83
+ """
84
+ return torch.log(torch.clamp(x, min=clip_val) * C)
85
+
86
+
87
+ def dynamic_range_decompression(x, C=1):
88
+ """
89
+ PARAMS
90
+ ------
91
+ C: compression factor used to compress
92
+ """
93
+ return torch.exp(x) / C
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py ADDED
@@ -0,0 +1,90 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ '''
4
+ Cleaners are transformations that run over the input text at both training and eval time.
5
+
6
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
+ 1. "english_cleaners" for English text
9
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
+ the symbols in symbols.py to match your data).
13
+ '''
14
+
15
+ import re
16
+ from unidecode import unidecode
17
+ from .numbers import normalize_numbers
18
+
19
+
20
+ # Regular expression matching whitespace:
21
+ _whitespace_re = re.compile(r'\s+')
22
+
23
+ # List of (regular expression, replacement) pairs for abbreviations:
24
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
25
+ ('mrs', 'misess'),
26
+ ('mr', 'mister'),
27
+ ('dr', 'doctor'),
28
+ ('st', 'saint'),
29
+ ('co', 'company'),
30
+ ('jr', 'junior'),
31
+ ('maj', 'major'),
32
+ ('gen', 'general'),
33
+ ('drs', 'doctors'),
34
+ ('rev', 'reverend'),
35
+ ('lt', 'lieutenant'),
36
+ ('hon', 'honorable'),
37
+ ('sgt', 'sergeant'),
38
+ ('capt', 'captain'),
39
+ ('esq', 'esquire'),
40
+ ('ltd', 'limited'),
41
+ ('col', 'colonel'),
42
+ ('ft', 'fort'),
43
+ ]]
44
+
45
+
46
+ def expand_abbreviations(text):
47
+ for regex, replacement in _abbreviations:
48
+ text = re.sub(regex, replacement, text)
49
+ return text
50
+
51
+
52
+ def expand_numbers(text):
53
+ return normalize_numbers(text)
54
+
55
+
56
+ def lowercase(text):
57
+ return text.lower()
58
+
59
+
60
+ def collapse_whitespace(text):
61
+ return re.sub(_whitespace_re, ' ', text)
62
+
63
+
64
+ def convert_to_ascii(text):
65
+ return unidecode(text)
66
+
67
+
68
+ def basic_cleaners(text):
69
+ '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
70
+ text = lowercase(text)
71
+ text = collapse_whitespace(text)
72
+ return text
73
+
74
+
75
+ def transliteration_cleaners(text):
76
+ '''Pipeline for non-English text that transliterates to ASCII.'''
77
+ text = convert_to_ascii(text)
78
+ text = lowercase(text)
79
+ text = collapse_whitespace(text)
80
+ return text
81
+
82
+
83
+ def english_cleaners(text):
84
+ '''Pipeline for English text, including number and abbreviation expansion.'''
85
+ text = convert_to_ascii(text)
86
+ text = lowercase(text)
87
+ text = expand_numbers(text)
88
+ text = expand_abbreviations(text)
89
+ text = collapse_whitespace(text)
90
+ return text
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py ADDED
@@ -0,0 +1,65 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ import re
4
+
5
+
6
+ valid_symbols = [
7
+ 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
8
+ 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
9
+ 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
10
+ 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
11
+ 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
12
+ 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
13
+ 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
14
+ ]
15
+
16
+ _valid_symbol_set = set(valid_symbols)
17
+
18
+
19
+ class CMUDict:
20
+ '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
21
+ def __init__(self, file_or_path, keep_ambiguous=True):
22
+ if isinstance(file_or_path, str):
23
+ with open(file_or_path, encoding='latin-1') as f:
24
+ entries = _parse_cmudict(f)
25
+ else:
26
+ entries = _parse_cmudict(file_or_path)
27
+ if not keep_ambiguous:
28
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
29
+ self._entries = entries
30
+
31
+
32
+ def __len__(self):
33
+ return len(self._entries)
34
+
35
+
36
+ def lookup(self, word):
37
+ '''Returns list of ARPAbet pronunciations of the given word.'''
38
+ return self._entries.get(word.upper())
39
+
40
+
41
+
42
+ _alt_re = re.compile(r'\([0-9]+\)')
43
+
44
+
45
+ def _parse_cmudict(file):
46
+ cmudict = {}
47
+ for line in file:
48
+ if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
49
+ parts = line.split(' ')
50
+ word = re.sub(_alt_re, '', parts[0])
51
+ pronunciation = _get_pronunciation(parts[1])
52
+ if pronunciation:
53
+ if word in cmudict:
54
+ cmudict[word].append(pronunciation)
55
+ else:
56
+ cmudict[word] = [pronunciation]
57
+ return cmudict
58
+
59
+
60
+ def _get_pronunciation(s):
61
+ parts = s.strip().split(' ')
62
+ for part in parts:
63
+ if part not in _valid_symbol_set:
64
+ return None
65
+ return ' '.join(parts)
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ from librosa.filters import mel as librosa_mel_fn
3
+ from .audio_processing import dynamic_range_compression
4
+ from .audio_processing import dynamic_range_decompression
5
+ from .stft import STFT
6
+ from .utils import get_mask_from_lengths
7
+
8
+
9
+ class LinearNorm(torch.nn.Module):
10
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
11
+ super(LinearNorm, self).__init__()
12
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
13
+
14
+ torch.nn.init.xavier_uniform_(
15
+ self.linear_layer.weight,
16
+ gain=torch.nn.init.calculate_gain(w_init_gain))
17
+
18
+ def forward(self, x):
19
+ return self.linear_layer(x)
20
+
21
+
22
+ class ConvNorm(torch.nn.Module):
23
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
24
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
25
+ super(ConvNorm, self).__init__()
26
+ if padding is None:
27
+ assert(kernel_size % 2 == 1)
28
+ padding = int(dilation * (kernel_size - 1) / 2)
29
+
30
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
31
+ kernel_size=kernel_size, stride=stride,
32
+ padding=padding, dilation=dilation,
33
+ bias=bias)
34
+
35
+ torch.nn.init.xavier_uniform_(
36
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
37
+
38
+ def forward(self, signal):
39
+ conv_signal = self.conv(signal)
40
+ return conv_signal
41
+
42
+
43
+ class GlobalAvgPool(torch.nn.Module):
44
+ def __init__(self):
45
+ super(GlobalAvgPool, self).__init__()
46
+
47
+ def forward(self, x, lengths=None):
48
+ """Average pooling across time steps (dim=1) with optionally lengths.
49
+ Args:
50
+ x: torch.Tensor of shape (N, T, ...)
51
+ lengths: None or torch.Tensor of shape (N,)
52
+ dim: dimension to pool
53
+ """
54
+ if lengths is None:
55
+ return x.mean(dim=1, keepdim=False)
56
+ else:
57
+ mask = get_mask_from_lengths(lengths).type(x.type()).to(x.device)
58
+ mask_shape = list(mask.size()) + [1 for _ in range(x.ndimension()-2)]
59
+ mask = mask.reshape(*mask_shape)
60
+ numer = (x * mask).sum(dim=1, keepdim=False)
61
+ denom = mask.sum(dim=1, keepdim=False)
62
+ return numer / denom
63
+
64
+
65
+ class TacotronSTFT(torch.nn.Module):
66
+ def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
67
+ n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
68
+ mel_fmax=8000.0):
69
+ super(TacotronSTFT, self).__init__()
70
+ self.n_mel_channels = n_mel_channels
71
+ self.sampling_rate = sampling_rate
72
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
73
+ mel_basis = librosa_mel_fn(
74
+ sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
75
+ mel_basis = torch.from_numpy(mel_basis).float()
76
+ self.register_buffer('mel_basis', mel_basis)
77
+
78
+ def spectral_normalize(self, magnitudes):
79
+ output = dynamic_range_compression(magnitudes)
80
+ return output
81
+
82
+ def spectral_de_normalize(self, magnitudes):
83
+ output = dynamic_range_decompression(magnitudes)
84
+ return output
85
+
86
+ def mel_spectrogram(self, y):
87
+ """Computes mel-spectrograms from a batch of waves
88
+ PARAMS
89
+ ------
90
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
91
+
92
+ RETURNS
93
+ -------
94
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
95
+ """
96
+ assert(torch.min(y.data) >= -1)
97
+ assert(torch.max(y.data) <= 1)
98
+
99
+ magnitudes, phases = self.stft_fn.transform(y)
100
+ magnitudes = magnitudes.data
101
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
102
+ mel_output = self.spectral_normalize(mel_output)
103
+ return mel_output
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py ADDED
@@ -0,0 +1,669 @@
1
+ from math import sqrt
2
+ import torch
3
+ import torch.distributions as distr
4
+ from torch.autograd import Variable
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from .layers import ConvNorm, LinearNorm, GlobalAvgPool
8
+ from .utils import to_gpu, get_mask_from_lengths
9
+
10
+
11
+ class LocationLayer(nn.Module):
12
+ def __init__(self, attention_n_filters, attention_kernel_size,
13
+ attention_dim):
14
+ super(LocationLayer, self).__init__()
15
+ padding = int((attention_kernel_size - 1) / 2)
16
+ self.location_conv = ConvNorm(2, attention_n_filters,
17
+ kernel_size=attention_kernel_size,
18
+ padding=padding, bias=False, stride=1,
19
+ dilation=1)
20
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
21
+ bias=False, w_init_gain='tanh')
22
+
23
+ def forward(self, attention_weights_cat):
24
+ processed_attention = self.location_conv(attention_weights_cat)
25
+ processed_attention = processed_attention.transpose(1, 2)
26
+ processed_attention = self.location_dense(processed_attention)
27
+ return processed_attention
28
+
29
+
30
+ class Attention(nn.Module):
31
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
32
+ attention_location_n_filters, attention_location_kernel_size):
33
+ super(Attention, self).__init__()
34
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
35
+ bias=False, w_init_gain='tanh')
36
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
37
+ w_init_gain='tanh')
38
+ self.v = LinearNorm(attention_dim, 1, bias=False)
39
+ self.location_layer = LocationLayer(attention_location_n_filters,
40
+ attention_location_kernel_size,
41
+ attention_dim)
42
+ self.score_mask_value = -float("inf")
43
+
44
+ def get_alignment_energies(self, query, processed_memory,
45
+ attention_weights_cat):
46
+ """
47
+ PARAMS
48
+ ------
49
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
50
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
51
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
52
+
53
+ RETURNS
54
+ -------
55
+ alignment (batch, max_time)
56
+ """
57
+
58
+ processed_query = self.query_layer(query.unsqueeze(1))
59
+ processed_attention_weights = self.location_layer(attention_weights_cat)
60
+ energies = self.v(torch.tanh(
61
+ processed_query + processed_attention_weights + processed_memory))
62
+
63
+ energies = energies.squeeze(-1)
64
+ return energies
65
+
66
+ def forward(self, attention_hidden_state, memory, processed_memory,
67
+ attention_weights_cat, mask):
68
+ """
69
+ PARAMS
70
+ ------
71
+ attention_hidden_state: attention rnn last output
72
+ memory: encoder outputs
73
+ processed_memory: processed encoder outputs
74
+ attention_weights_cat: previous and cumulative attention weights
75
+ mask: binary mask for padded data
76
+ """
77
+ alignment = self.get_alignment_energies(
78
+ attention_hidden_state, processed_memory, attention_weights_cat)
79
+
80
+ if mask is not None:
81
+ alignment.data.masked_fill_(mask, self.score_mask_value)
82
+
83
+ attention_weights = F.softmax(alignment, dim=1)
84
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
85
+ attention_context = attention_context.squeeze(1)
86
+
87
+ return attention_context, attention_weights
88
+
89
+
90
+ class Prenet(nn.Module):
91
+ def __init__(self, in_dim, sizes):
92
+ super(Prenet, self).__init__()
93
+ in_sizes = [in_dim] + sizes[:-1]
94
+ self.layers = nn.ModuleList(
95
+ [LinearNorm(in_size, out_size, bias=False)
96
+ for (in_size, out_size) in zip(in_sizes, sizes)])
97
+
98
+ def forward(self, x):
99
+ for linear in self.layers:
100
+ x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
101
+ return x
102
+
103
+
104
+ class Postnet(nn.Module):
105
+ """Postnet
106
+ - Five 1-d convolutions with 512 channels and kernel size 5
107
+ """
108
+
109
+ def __init__(self, hparams):
110
+ super(Postnet, self).__init__()
111
+ self.convolutions = nn.ModuleList()
112
+
113
+ self.convolutions.append(
114
+ nn.Sequential(
115
+ ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
116
+ kernel_size=hparams.postnet_kernel_size, stride=1,
117
+ padding=int((hparams.postnet_kernel_size - 1) / 2),
118
+ dilation=1, w_init_gain='tanh'),
119
+ nn.BatchNorm1d(hparams.postnet_embedding_dim))
120
+ )
121
+
122
+ for i in range(1, hparams.postnet_n_convolutions - 1):
123
+ self.convolutions.append(
124
+ nn.Sequential(
125
+ ConvNorm(hparams.postnet_embedding_dim,
126
+ hparams.postnet_embedding_dim,
127
+ kernel_size=hparams.postnet_kernel_size, stride=1,
128
+ padding=int((hparams.postnet_kernel_size - 1) / 2),
129
+ dilation=1, w_init_gain='tanh'),
130
+ nn.BatchNorm1d(hparams.postnet_embedding_dim))
131
+ )
132
+
133
+ self.convolutions.append(
134
+ nn.Sequential(
135
+ ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
136
+ kernel_size=hparams.postnet_kernel_size, stride=1,
137
+ padding=int((hparams.postnet_kernel_size - 1) / 2),
138
+ dilation=1, w_init_gain='linear'),
139
+ nn.BatchNorm1d(hparams.n_mel_channels))
140
+ )
141
+
142
+ def forward(self, x):
143
+ for i in range(len(self.convolutions) - 1):
144
+ x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
145
+ x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
146
+
147
+ return x
148
+
149
+
150
+ class Encoder(nn.Module):
151
+ """Encoder module:
152
+ - Three 1-d convolution banks
153
+ - Bidirectional LSTM
154
+ """
155
+ def __init__(self, hparams):
156
+ super(Encoder, self).__init__()
157
+
158
+ convolutions = []
159
+ for _ in range(hparams.encoder_n_convolutions):
160
+ conv_layer = nn.Sequential(
161
+ ConvNorm(hparams.encoder_embedding_dim,
162
+ hparams.encoder_embedding_dim,
163
+ kernel_size=hparams.encoder_kernel_size, stride=1,
164
+ padding=int((hparams.encoder_kernel_size - 1) / 2),
165
+ dilation=1, w_init_gain='relu'),
166
+ nn.BatchNorm1d(hparams.encoder_embedding_dim))
167
+ convolutions.append(conv_layer)
168
+ self.convolutions = nn.ModuleList(convolutions)
169
+
170
+ self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
171
+ int(hparams.encoder_embedding_dim / 2), 1,
172
+ batch_first=True, bidirectional=True)
173
+
174
+ def forward(self, x, input_lengths):
175
+ for conv in self.convolutions:
176
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training)
177
+
178
+ x = x.transpose(1, 2)
179
+
180
+ # pytorch tensors are not reversible, hence the conversion
181
+ input_lengths = input_lengths.cpu().numpy()
182
+ x = nn.utils.rnn.pack_padded_sequence(
183
+ x, input_lengths, batch_first=True)
184
+
185
+ self.lstm.flatten_parameters()
186
+ outputs, _ = self.lstm(x)
187
+
188
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(
189
+ outputs, batch_first=True)
190
+
191
+ return outputs
192
+
193
+ def inference(self, x):
194
+ for conv in self.convolutions:
195
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training)
196
+
197
+ x = x.transpose(1, 2)
198
+
199
+ self.lstm.flatten_parameters()
200
+ outputs, _ = self.lstm(x)
201
+
202
+ return outputs
203
+
204
+
205
+ class AudioEncoder(nn.Module):
206
+ def __init__(self, hparams):
207
+ super(AudioEncoder, self).__init__()
208
+
209
+ assert hparams.lat_dim > 0
210
+
211
+ convolutions = []
212
+ inp_dim = hparams.n_mel_channels
213
+ for _ in range(hparams.lat_n_convolutions):
214
+ conv_layer = nn.Sequential(
215
+ ConvNorm(inp_dim, hparams.lat_n_filters,
216
+ kernel_size=hparams.lat_kernel_size, stride=1,
217
+ padding=int((hparams.lat_kernel_size - 1) / 2),
218
+ dilation=1, w_init_gain='tanh'),
219
+ nn.BatchNorm1d(hparams.lat_n_filters))
220
+ inp_dim = hparams.lat_n_filters
221
+ convolutions.append(conv_layer)
222
+ self.convolutions = nn.ModuleList(convolutions)
223
+
224
+ self.lstm = nn.LSTM(hparams.lat_n_filters,
225
+ int(hparams.lat_n_filters / 2),
226
+ hparams.lat_n_blstms, batch_first=True,
227
+ bidirectional=True)
228
+ self.pool = GlobalAvgPool()
229
+
230
+ self.mu_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
231
+ self.logvar_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
232
+ self.lat_dim = hparams.lat_dim
233
+
234
+ def forward(self, x, lengths):
235
+ """
236
+ Args:
237
+ x (torch.Tensor): (B, F, T)
238
+ """
239
+
240
+ for conv in self.convolutions:
241
+ x = F.dropout(torch.tanh(conv(x)), 0.5, self.training)
242
+
243
+ x = x.transpose(1, 2) # (B, T, D)
244
+
245
+ # x may not be sorted by length. Sort->process->unsort
246
+ max_len = x.size(1)
247
+ assert max_len == torch.max(lengths).item()
248
+
249
+ lengths, perm_idx = lengths.sort(0, descending=True)
250
+ x = x[perm_idx]
251
+ x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
252
+
253
+ self.lstm.flatten_parameters()
254
+ outputs, _ = self.lstm(x)
255
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
256
+
257
+ _, unperm_idx = perm_idx.sort(0)
258
+ outputs = outputs[unperm_idx] # (B, T, D)
259
+ lengths = lengths[unperm_idx] # (B,)
260
+
261
+ outputs = self.pool(outputs, lengths) # (B, D)
262
+
263
+ mu = self.mu_proj(outputs)
264
+ logvar = self.logvar_proj(outputs)
265
+ z = distr.Normal(mu, logvar).rsample()
266
+ return z, mu, logvar
267
+
268
+
269
+ class Decoder(nn.Module):
270
+ def __init__(self, hparams):
271
+ super(Decoder, self).__init__()
272
+ self.n_mel_channels = hparams.n_mel_channels
273
+ self.n_frames_per_step = hparams.n_frames_per_step
274
+ self.encoder_embedding_dim = hparams.encoder_embedding_dim
275
+ self.obs_dim = hparams.obs_dim
276
+ self.lat_dim = hparams.lat_dim
277
+ self.attention_rnn_dim = hparams.attention_rnn_dim
278
+ self.decoder_rnn_dim = hparams.decoder_rnn_dim
279
+ self.prenet_dim = hparams.prenet_dim
280
+ self.max_decoder_steps = hparams.max_decoder_steps
281
+ self.gate_threshold = hparams.gate_threshold
282
+ self.p_attention_dropout = hparams.p_attention_dropout
283
+ self.p_decoder_dropout = hparams.p_decoder_dropout
284
+
285
+ self.prenet = Prenet(
286
+ hparams.n_mel_channels * hparams.n_frames_per_step,
287
+ [hparams.prenet_dim, hparams.prenet_dim])
288
+
289
+ self.attention_rnn = nn.LSTMCell(
290
+ hparams.prenet_dim + hparams.encoder_embedding_dim,
291
+ hparams.attention_rnn_dim)
292
+
293
+ self.attention_layer = Attention(
294
+ hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
295
+ hparams.attention_dim, hparams.attention_location_n_filters,
296
+ hparams.attention_location_kernel_size)
297
+
298
+ encoder_tot_dim = (hparams.encoder_embedding_dim + \
299
+ hparams.lat_dim + hparams.obs_dim)
300
+ self.decoder_rnn = nn.LSTMCell(
301
+ hparams.attention_rnn_dim + encoder_tot_dim,
302
+ hparams.decoder_rnn_dim, 1)
303
+
304
+ self.linear_projection = LinearNorm(
305
+ hparams.decoder_rnn_dim + encoder_tot_dim,
306
+ hparams.n_mel_channels * hparams.n_frames_per_step)
307
+
308
+ self.gate_layer = LinearNorm(
309
+ hparams.decoder_rnn_dim + encoder_tot_dim, 1,
310
+ bias=True, w_init_gain='sigmoid')
311
+
312
+ def get_go_frame(self, memory):
313
+ """ Gets all zeros frames to use as first decoder input
314
+ PARAMS
315
+ ------
316
+ memory: decoder outputs
317
+
318
+ RETURNS
319
+ -------
320
+ decoder_input: all zeros frames
321
+ """
322
+ B = memory.size(0)
323
+ decoder_input = Variable(memory.data.new(
324
+ B, self.n_mel_channels * self.n_frames_per_step).zero_())
325
+ return decoder_input
326
+
327
+ def initialize_decoder_states(self, memory, obs_and_lat, mask):
328
+ """ Initializes attention rnn states, decoder rnn states, attention
329
+ weights, attention cumulative weights, attention context, stores memory
330
+ and stores processed memory
331
+ PARAMS
332
+ ------
333
+ memory: Encoder outputs
334
+ obs_and_lat: Observed and latent attribute embeddings
335
+ mask: Mask for padded data if training, expects None for inference
336
+ """
337
+ B = memory.size(0)
338
+ MAX_TIME = memory.size(1)
339
+
340
+ self.attention_hidden = Variable(memory.data.new(
341
+ B, self.attention_rnn_dim).zero_())
342
+ self.attention_cell = Variable(memory.data.new(
343
+ B, self.attention_rnn_dim).zero_())
344
+
345
+ self.decoder_hidden = Variable(memory.data.new(
346
+ B, self.decoder_rnn_dim).zero_())
347
+ self.decoder_cell = Variable(memory.data.new(
348
+ B, self.decoder_rnn_dim).zero_())
349
+
350
+ self.attention_weights = Variable(memory.data.new(
351
+ B, MAX_TIME).zero_())
352
+ self.attention_weights_cum = Variable(memory.data.new(
353
+ B, MAX_TIME).zero_())
354
+ self.attention_context = Variable(memory.data.new(
355
+ B, self.encoder_embedding_dim).zero_())
356
+
357
+ self.memory = memory
358
+ self.processed_memory = self.attention_layer.memory_layer(memory)
359
+ self.obs_and_lat = obs_and_lat
360
+ self.mask = mask
361
+
362
+ def parse_decoder_inputs(self, decoder_inputs):
363
+ """ Prepares decoder inputs, i.e. mel outputs
364
+ PARAMS
365
+ ------
366
+ decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
367
+
368
+ RETURNS
369
+ -------
370
+ inputs: processed decoder inputs
371
+
372
+ """
373
+ # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
374
+ decoder_inputs = decoder_inputs.transpose(1, 2)
375
+ decoder_inputs = decoder_inputs.view(
376
+ decoder_inputs.size(0),
377
+ int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
378
+ # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
379
+ decoder_inputs = decoder_inputs.transpose(0, 1)
380
+ return decoder_inputs
381
+
382
+ def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
383
+ """ Prepares decoder outputs for output
384
+ PARAMS
385
+ ------
386
+ mel_outputs:
387
+ gate_outputs: gate output energies
388
+ alignments:
389
+
390
+ RETURNS
391
+ -------
392
+ mel_outputs:
393
+ gate_outputs: gate output energies
394
+ alignments:
395
+ """
396
+ # (T_out, B) -> (B, T_out)
397
+ alignments = torch.stack(alignments).transpose(0, 1)
398
+ # (T_out, B) -> (B, T_out)
399
+ gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
400
+ gate_outputs = gate_outputs.contiguous()
401
+ # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
402
+ mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
403
+ # decouple frames per step
404
+ mel_outputs = mel_outputs.view(
405
+ mel_outputs.size(0), -1, self.n_mel_channels)
406
+ # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
407
+ mel_outputs = mel_outputs.transpose(1, 2)
408
+
409
+ return mel_outputs, gate_outputs, alignments
410
+
411
+ def decode(self, decoder_input):
412
+ """ Decoder step using stored states, attention and memory
413
+ PARAMS
414
+ ------
415
+ decoder_input: previous mel output
416
+
417
+ RETURNS
418
+ -------
419
+ mel_output:
420
+ gate_output: gate output energies
421
+ attention_weights:
422
+ """
423
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
424
+ self.attention_hidden, self.attention_cell = self.attention_rnn(
425
+ cell_input, (self.attention_hidden, self.attention_cell))
426
+ self.attention_hidden = F.dropout(
427
+ self.attention_hidden, self.p_attention_dropout, self.training)
428
+
429
+ attention_weights_cat = torch.cat(
430
+ (self.attention_weights.unsqueeze(1),
431
+ self.attention_weights_cum.unsqueeze(1)), dim=1)
432
+ self.attention_context, self.attention_weights = self.attention_layer(
433
+ self.attention_hidden, self.memory, self.processed_memory,
434
+ attention_weights_cat, self.mask)
435
+
436
+ self.attention_weights_cum += self.attention_weights
437
+ decoder_input = torch.cat(
438
+ (self.attention_hidden, self.attention_context), -1)
439
+ if self.obs_and_lat is not None:
440
+ decoder_input = torch.cat((decoder_input, self.obs_and_lat), -1)
441
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
442
+ decoder_input, (self.decoder_hidden, self.decoder_cell))
443
+ self.decoder_hidden = F.dropout(
444
+ self.decoder_hidden, self.p_decoder_dropout, self.training)
445
+
446
+ decoder_hidden_attention_context = torch.cat(
447
+ (self.decoder_hidden, self.attention_context), dim=1)
448
+ if self.obs_and_lat is not None:
449
+ decoder_hidden_attention_context = torch.cat(
450
+ (decoder_hidden_attention_context, self.obs_and_lat), dim=1)
451
+ decoder_output = self.linear_projection(
452
+ decoder_hidden_attention_context)
453
+
454
+ gate_prediction = self.gate_layer(decoder_hidden_attention_context)
455
+ return decoder_output, gate_prediction, self.attention_weights
456
+
457
+ def forward(self, memory, obs_and_lat, decoder_inputs, memory_lengths):
458
+ """ Decoder forward pass for training
459
+ PARAMS
460
+ ------
461
+ memory: Encoder outputs
462
+ obs_and_lat: Observed and latent attribute embeddings
463
+ decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
464
+ memory_lengths: Encoder output lengths for attention masking.
465
+
466
+ RETURNS
467
+ -------
468
+ mel_outputs: mel outputs from the decoder
469
+ gate_outputs: gate outputs from the decoder
470
+ alignments: sequence of attention weights from the decoder
471
+ """
472
+
473
+ decoder_input = self.get_go_frame(memory).unsqueeze(0)
474
+ decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
475
+ decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
476
+ decoder_inputs = self.prenet(decoder_inputs)
477
+
478
+ self.initialize_decoder_states(
479
+ memory, obs_and_lat, mask=~get_mask_from_lengths(memory_lengths))
480
+
481
+ mel_outputs, gate_outputs, alignments = [], [], []
482
+ while len(mel_outputs) < decoder_inputs.size(0) - 1:
483
+ decoder_input = decoder_inputs[len(mel_outputs)]
484
+ mel_output, gate_output, attention_weights = self.decode(
485
+ decoder_input)
486
+ mel_outputs += [mel_output.squeeze(1)]
487
+ gate_outputs += [gate_output.squeeze()]
488
+ alignments += [attention_weights]
489
+
490
+ mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
491
+ mel_outputs, gate_outputs, alignments)
492
+
493
+ return mel_outputs, gate_outputs, alignments
494
+
495
+ def inference(self, memory, obs_and_lat, ret_has_eos=False):
496
+ """ Decoder inference
497
+ PARAMS
498
+ ------
499
+ memory: Encoder outputs
500
+ obs_and_lat: Observed and latent attribute embeddings
501
+
502
+ RETURNS
503
+ -------
504
+ mel_outputs: mel outputs from the decoder
505
+ gate_outputs: gate outputs from the decoder
506
+ alignments: sequence of attention weights from the decoder
507
+ """
508
+ decoder_input = self.get_go_frame(memory)
509
+
510
+ self.initialize_decoder_states(memory, obs_and_lat, mask=None)
511
+
512
+ mel_outputs, gate_outputs, alignments = [], [], []
513
+ has_eos = False
514
+ while True:
515
+ decoder_input = self.prenet(decoder_input)
516
+ mel_output, gate_output, alignment = self.decode(decoder_input)
517
+
518
+ mel_outputs += [mel_output.squeeze(1)]
519
+ gate_outputs += [gate_output]
520
+ alignments += [alignment]
521
+
522
+ if torch.sigmoid(gate_output.data) > self.gate_threshold:
523
+ has_eos = True
524
+ break
525
+ elif len(mel_outputs) == self.max_decoder_steps:
526
+ # print("Warning! Reached max decoder steps")
527
+ break
528
+
529
+ decoder_input = mel_output
530
+
531
+ mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
532
+ mel_outputs, gate_outputs, alignments)
533
+
534
+ if ret_has_eos:
535
+ return mel_outputs, gate_outputs, alignments, has_eos
536
+ else:
537
+ return mel_outputs, gate_outputs, alignments
538
+
539
+
540
+ class Tacotron2(nn.Module):
541
+ def __init__(self, hparams):
542
+ super(Tacotron2, self).__init__()
543
+ self.mask_padding = hparams.mask_padding
544
+ self.fp16_run = hparams.fp16_run
545
+ self.n_mel_channels = hparams.n_mel_channels
546
+ self.n_frames_per_step = hparams.n_frames_per_step
547
+
548
+ # initialize text encoder embedding
549
+ self.embedding = nn.Embedding(
550
+ hparams.n_symbols, hparams.symbols_embedding_dim)
551
+ std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
552
+ val = sqrt(3.0) * std # uniform bounds for std
553
+ self.embedding.weight.data.uniform_(-val, val)
554
+
555
+ # initialize observed attribute embedding
556
+ self.obs_embedding = None
557
+ if hparams.obs_dim > 0:
558
+ self.obs_embedding = nn.Embedding(
559
+ hparams.obs_n_class, hparams.obs_dim)
560
+ std = sqrt(2.0 / (hparams.obs_n_class + hparams.obs_dim))
561
+ val = sqrt(3.0) * std # uniform bounds for std
562
+ self.obs_embedding.weight.data.uniform_(-val, val)
563
+
564
+ self.encoder = Encoder(hparams)
565
+ self.decoder = Decoder(hparams)
566
+ self.postnet = Postnet(hparams)
567
+
568
+ self.lat_encoder = None
569
+ if hparams.lat_dim > 0:
570
+ self.lat_encoder = AudioEncoder(hparams)
571
+
572
+ def parse_batch(self, batch):
573
+ (text_padded, input_lengths, obs_labels,
574
+ mel_padded, gate_padded, output_lengths) = batch
575
+ text_padded = to_gpu(text_padded).long()
576
+ input_lengths = to_gpu(input_lengths).long()
577
+ obs_labels = to_gpu(obs_labels).long()
578
+ max_len = torch.max(input_lengths.data).item()
579
+ mel_padded = to_gpu(mel_padded).float()
580
+ gate_padded = to_gpu(gate_padded).float()
581
+ output_lengths = to_gpu(output_lengths).long()
582
+
583
+ return (
584
+ (text_padded, input_lengths, obs_labels,
585
+ mel_padded, max_len, output_lengths),
586
+ (mel_padded, gate_padded))
587
+
588
+ def parse_output(self, outputs, output_lengths=None):
589
+ if self.mask_padding and output_lengths is not None:
590
+ mask = ~get_mask_from_lengths(output_lengths)
591
+ mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
592
+ mask = mask.permute(1, 0, 2)
593
+
594
+ outputs[0].data.masked_fill_(mask, 0.0)
595
+ outputs[1].data.masked_fill_(mask, 0.0)
596
+ outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
597
+
598
+ return outputs
599
+
600
+ def forward(self, inputs):
601
+ (text_inputs, text_lengths, obs_labels,
602
+ mels, max_len, output_lengths) = inputs
603
+ text_lengths, output_lengths = text_lengths.data, output_lengths.data
604
+
605
+ embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
606
+
607
+ encoder_outputs = self.encoder(embedded_inputs, text_lengths)
608
+
609
+ obs = None
610
+ if self.obs_embedding is not None:
611
+ obs = self.obs_embedding(obs_labels)
612
+
613
+ lat, lat_mu, lat_logvar = None, None, None
614
+ if self.lat_encoder is not None:
615
+ (lat, lat_mu, lat_logvar) = self.lat_encoder(mels, output_lengths)
616
+
617
+ obs_and_lat = [x for x in [obs, lat] if x is not None]
618
+ if bool(obs_and_lat):
619
+ obs_and_lat = torch.cat(obs_and_lat, dim=-1)
620
+ else:
621
+ obs_and_lat = None
622
+
623
+ mel_outputs, gate_outputs, alignments = self.decoder(
624
+ encoder_outputs, obs_and_lat, mels, memory_lengths=text_lengths)
625
+
626
+ mel_outputs_postnet = self.postnet(mel_outputs)
627
+ mel_outputs_postnet = mel_outputs + mel_outputs_postnet
628
+
629
+ return self.parse_output(
630
+ [mel_outputs, mel_outputs_postnet, gate_outputs, alignments,
631
+ lat_mu, lat_logvar],
632
+ output_lengths)
633
+
634
+ def inference(self, inputs, obs_labels=None, lat=None, ret_has_eos=False):
635
+ embedded_inputs = self.embedding(inputs).transpose(1, 2)
636
+ encoder_outputs = self.encoder.inference(embedded_inputs)
637
+
638
+ if obs_labels is None:
639
+ obs_labels = torch.LongTensor(len(inputs))
640
+ obs_labels = obs_labels.to(inputs.device).zero_()
641
+
642
+ obs = None
643
+ if self.obs_embedding is not None:
644
+ obs = self.obs_embedding(obs_labels)
645
+
646
+ if self.lat_encoder is not None:
647
+ if lat is None:
648
+ lat = torch.FloatTensor(len(inputs), self.lat_encoder.lat_dim)
649
+ lat = lat.to(inputs.device).zero_().type(encoder_outputs.type())
650
+
651
+ obs_and_lat = [x for x in [obs, lat] if x is not None]
652
+ if bool(obs_and_lat):
653
+ obs_and_lat = torch.cat(obs_and_lat, dim=-1)
654
+ else:
655
+ obs_and_lat = None
656
+
657
+ mel_outputs, gate_outputs, alignments, has_eos = self.decoder.inference(
658
+ encoder_outputs, obs_and_lat, ret_has_eos=True)
659
+
660
+ mel_outputs_postnet = self.postnet(mel_outputs)
661
+ mel_outputs_postnet = mel_outputs + mel_outputs_postnet
662
+
663
+ outputs = self.parse_output(
664
+ [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
665
+
666
+ if ret_has_eos:
667
+ return outputs + [has_eos]
668
+ else:
669
+ return outputs
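
A minimal inference sketch for the `Tacotron2` module defined above. It assumes a checkpoint dictionary with `"hparams"` and `"model_dict"` keys, which is the layout `load_tacotron()` in `examples/textless_nlp/gslm/unit2speech/utils.py` (added later in this change) expects; the checkpoint path and the unit IDs are placeholders, and a GPU is used here to match the rest of the unit2speech tooling:

```python
import torch
from examples.textless_nlp.gslm.unit2speech.tacotron2.model import Tacotron2

ckpt = torch.load("tts_checkpoint.pt")               # hypothetical checkpoint path
model = Tacotron2(ckpt["hparams"]).cuda().eval()
model.load_state_dict(ckpt["model_dict"])

units = torch.LongTensor([[17, 17, 42, 42, 3]]).cuda()   # dummy (B=1, T) symbol IDs
with torch.no_grad():
    mel, mel_postnet, gate, align, has_eos = model.inference(units, ret_has_eos=True)
print(mel_postnet.shape, has_eos)                    # (1, n_mel_channels, T_out), bool
```
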
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py ADDED
@@ -0,0 +1,71 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ import inflect
4
+ import re
5
+
6
+
7
+ _inflect = inflect.engine()
8
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
9
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
10
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
11
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
12
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
13
+ _number_re = re.compile(r'[0-9]+')
14
+
15
+
16
+ def _remove_commas(m):
17
+ return m.group(1).replace(',', '')
18
+
19
+
20
+ def _expand_decimal_point(m):
21
+ return m.group(1).replace('.', ' point ')
22
+
23
+
24
+ def _expand_dollars(m):
25
+ match = m.group(1)
26
+ parts = match.split('.')
27
+ if len(parts) > 2:
28
+ return match + ' dollars' # Unexpected format
29
+ dollars = int(parts[0]) if parts[0] else 0
30
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31
+ if dollars and cents:
32
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
33
+ cent_unit = 'cent' if cents == 1 else 'cents'
34
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
35
+ elif dollars:
36
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
37
+ return '%s %s' % (dollars, dollar_unit)
38
+ elif cents:
39
+ cent_unit = 'cent' if cents == 1 else 'cents'
40
+ return '%s %s' % (cents, cent_unit)
41
+ else:
42
+ return 'zero dollars'
43
+
44
+
45
+ def _expand_ordinal(m):
46
+ return _inflect.number_to_words(m.group(0))
47
+
48
+
49
+ def _expand_number(m):
50
+ num = int(m.group(0))
51
+ if num > 1000 and num < 3000:
52
+ if num == 2000:
53
+ return 'two thousand'
54
+ elif num > 2000 and num < 2010:
55
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
56
+ elif num % 100 == 0:
57
+ return _inflect.number_to_words(num // 100) + ' hundred'
58
+ else:
59
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
60
+ else:
61
+ return _inflect.number_to_words(num, andword='')
62
+
63
+
64
+ def normalize_numbers(text):
65
+ text = re.sub(_comma_number_re, _remove_commas, text)
66
+ text = re.sub(_pounds_re, r'\1 pounds', text)
67
+ text = re.sub(_dollars_re, _expand_dollars, text)
68
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
69
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
70
+ text = re.sub(_number_re, _expand_number, text)
71
+ return text
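
A quick usage note for `normalize_numbers` (it needs the `inflect` package). The expected output below was traced through the rules above by hand, so treat it as approximate rather than authoritative:

```python
from examples.textless_nlp.gslm.unit2speech.tacotron2.numbers import normalize_numbers

# Dollar amounts are expanded first, then any remaining digits are spelled out.
print(normalize_numbers("I paid $5.50 for 23 apples."))
# -> "I paid five dollars, fifty cents for twenty-three apples."
```
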
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py ADDED
@@ -0,0 +1,141 @@
1
+ """
2
+ BSD 3-Clause License
3
+
4
+ Copyright (c) 2017, Prem Seetharaman
5
+ All rights reserved.
6
+
7
+ * Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice,
11
+ this list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice, this
14
+ list of conditions and the following disclaimer in the
15
+ documentation and/or other materials provided with the distribution.
16
+
17
+ * Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived from this
19
+ software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ """
32
+
33
+ import torch
34
+ import numpy as np
35
+ import torch.nn.functional as F
36
+ from torch.autograd import Variable
37
+ from scipy.signal import get_window
38
+ from librosa.util import pad_center, tiny
39
+ from .audio_processing import window_sumsquare
40
+
41
+
42
+ class STFT(torch.nn.Module):
43
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
44
+ def __init__(self, filter_length=800, hop_length=200, win_length=800,
45
+ window='hann'):
46
+ super(STFT, self).__init__()
47
+ self.filter_length = filter_length
48
+ self.hop_length = hop_length
49
+ self.win_length = win_length
50
+ self.window = window
51
+ self.forward_transform = None
52
+ scale = self.filter_length / self.hop_length
53
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
54
+
55
+ cutoff = int((self.filter_length / 2 + 1))
56
+ fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
57
+ np.imag(fourier_basis[:cutoff, :])])
58
+
59
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
60
+ inverse_basis = torch.FloatTensor(
61
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :])
62
+
63
+ if window is not None:
64
+ assert(filter_length >= win_length)
65
+ # get window and zero center pad it to filter_length
66
+ fft_window = get_window(window, win_length, fftbins=True)
67
+ fft_window = pad_center(fft_window, filter_length)
68
+ fft_window = torch.from_numpy(fft_window).float()
69
+
70
+ # window the bases
71
+ forward_basis *= fft_window
72
+ inverse_basis *= fft_window
73
+
74
+ self.register_buffer('forward_basis', forward_basis.float())
75
+ self.register_buffer('inverse_basis', inverse_basis.float())
76
+
77
+ def transform(self, input_data):
78
+ num_batches = input_data.size(0)
79
+ num_samples = input_data.size(1)
80
+
81
+ self.num_samples = num_samples
82
+
83
+ # similar to librosa, reflect-pad the input
84
+ input_data = input_data.view(num_batches, 1, num_samples)
85
+ input_data = F.pad(
86
+ input_data.unsqueeze(1),
87
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
88
+ mode='reflect')
89
+ input_data = input_data.squeeze(1)
90
+
91
+ forward_transform = F.conv1d(
92
+ input_data,
93
+ Variable(self.forward_basis, requires_grad=False),
94
+ stride=self.hop_length,
95
+ padding=0)
96
+
97
+ cutoff = int((self.filter_length / 2) + 1)
98
+ real_part = forward_transform[:, :cutoff, :]
99
+ imag_part = forward_transform[:, cutoff:, :]
100
+
101
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
102
+ phase = torch.autograd.Variable(
103
+ torch.atan2(imag_part.data, real_part.data))
104
+
105
+ return magnitude, phase
106
+
107
+ def inverse(self, magnitude, phase):
108
+ recombine_magnitude_phase = torch.cat(
109
+ [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
110
+
111
+ inverse_transform = F.conv_transpose1d(
112
+ recombine_magnitude_phase,
113
+ Variable(self.inverse_basis, requires_grad=False),
114
+ stride=self.hop_length,
115
+ padding=0)
116
+
117
+ if self.window is not None:
118
+ window_sum = window_sumsquare(
119
+ self.window, magnitude.size(-1), hop_length=self.hop_length,
120
+ win_length=self.win_length, n_fft=self.filter_length,
121
+ dtype=np.float32)
122
+ # remove modulation effects
123
+ approx_nonzero_indices = torch.from_numpy(
124
+ np.where(window_sum > tiny(window_sum))[0])
125
+ window_sum = torch.autograd.Variable(
126
+ torch.from_numpy(window_sum), requires_grad=False)
127
+ window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
128
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
129
+
130
+ # scale by hop ratio
131
+ inverse_transform *= float(self.filter_length) / self.hop_length
132
+
133
+ inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
134
+ inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
135
+
136
+ return inverse_transform
137
+
138
+ def forward(self, input_data):
139
+ self.magnitude, self.phase = self.transform(input_data)
140
+ reconstruction = self.inverse(self.magnitude, self.phase)
141
+ return reconstruction
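
A small round-trip sketch for the `STFT` module above on a synthetic tone. It assumes `scipy` and an older `librosa` (whose `pad_center` accepts a positional size, as used in `__init__`) are installed:

```python
import math
import torch
from examples.textless_nlp.gslm.unit2speech.tacotron2.stft import STFT

stft = STFT(filter_length=800, hop_length=200, win_length=800, window='hann')

t = torch.arange(16000) / 16000.0
audio = 0.5 * torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)  # (1, 16000) 440 Hz tone

magnitude, phase = stft.transform(audio)      # each (1, filter_length // 2 + 1, frames)
reconstruction = stft.inverse(magnitude, phase)
print(magnitude.shape, reconstruction.shape)
```
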
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py ADDED
@@ -0,0 +1,18 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ '''
4
+ Defines the set of symbols used in text input to the model.
5
+
6
+ The default is a set of ASCII characters that works well for English or text that has been run
+ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
+ '''
7
+ from . import cmudict
8
+
9
+ _pad = '_'
10
+ _punctuation = '!\'(),.:;? '
11
+ _special = '-'
12
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
13
+
14
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
15
+ _arpabet = ['@' + s for s in cmudict.valid_symbols]
16
+
17
+ # Export all symbols:
18
+ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py ADDED
@@ -0,0 +1,107 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+ import numpy as np
3
+ import re
4
+ from . import cleaners
5
+ from .symbols import symbols
6
+
7
+
8
+ # Mappings from symbol to numeric ID and vice versa:
9
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
10
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
11
+
12
+ # Regular expression matching text enclosed in curly braces:
13
+ _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
14
+
15
+ # Special symbols
16
+ SOS_TOK = '<s>'
17
+ EOS_TOK = '</s>'
18
+
19
+ def text_to_sequence(text, cleaner_names):
20
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
21
+
22
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
23
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
24
+
25
+ Args:
26
+ text: string to convert to a sequence
27
+ cleaner_names: names of the cleaner functions to run the text through
28
+
29
+ Returns:
30
+ List of integers corresponding to the symbols in the text
31
+ '''
32
+ sequence = []
33
+
34
+ # Check for curly braces and treat their contents as ARPAbet:
35
+ while len(text):
36
+ m = _curly_re.match(text)
37
+ if not m:
38
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
39
+ break
40
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
41
+ sequence += _arpabet_to_sequence(m.group(2))
42
+ text = m.group(3)
43
+
44
+ return sequence
45
+
46
+
47
+ def sample_code_chunk(code, size):
48
+ assert(size > 0 and size <= len(code))
49
+ start = np.random.randint(len(code) - size + 1)
50
+ end = start + size
51
+ return code[start:end], start, end
52
+
53
+
54
+ def code_to_sequence(code, code_dict, collapse_code):
55
+ if collapse_code:
56
+ prev_c = None
57
+ sequence = []
58
+ for c in code:
59
+ if c in code_dict and c != prev_c:
60
+ sequence.append(code_dict[c])
61
+ prev_c = c
62
+ else:
63
+ sequence = [code_dict[c] for c in code if c in code_dict]
64
+ if len(sequence) < 0.95 * len(code):
65
+ print('WARNING: over 5% of codes are OOV')
66
+
67
+ return sequence
68
+
69
+
70
+ def sequence_to_text(sequence):
71
+ '''Converts a sequence of IDs back to a string'''
72
+ result = ''
73
+ for symbol_id in sequence:
74
+ if symbol_id in _id_to_symbol:
75
+ s = _id_to_symbol[symbol_id]
76
+ # Enclose ARPAbet back in curly braces:
77
+ if len(s) > 1 and s[0] == '@':
78
+ s = '{%s}' % s[1:]
79
+ result += s
80
+ return result.replace('}{', ' ')
81
+
82
+
83
+ def sequence_to_code(sequence, code_dict):
84
+ '''Analogous to sequence_to_text'''
85
+ id_to_code = {i: c for c, i in code_dict.items()}
86
+ return ' '.join([id_to_code[i] for i in sequence])
87
+
88
+
89
+ def _clean_text(text, cleaner_names):
90
+ for name in cleaner_names:
91
+ cleaner = getattr(cleaners, name)
92
+ if not cleaner:
93
+ raise Exception('Unknown cleaner: %s' % name)
94
+ text = cleaner(text)
95
+ return text
96
+
97
+
98
+ def _symbols_to_sequence(symbols):
99
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
100
+
101
+
102
+ def _arpabet_to_sequence(text):
103
+ return _symbols_to_sequence(['@' + s for s in text.split()])
104
+
105
+
106
+ def _should_keep_symbol(s):
107
+ return s in _symbol_to_id and s != '_' and s != '~'
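
A short sketch of the two front-ends above. The text path runs `text_to_sequence` with the `english_cleaners` pipeline (which pulls in `unidecode` and `inflect` through `cleaners.py`); the unit path runs `code_to_sequence` with a code dictionary, shown here with a made-up three-unit dictionary purely for illustration:

```python
from examples.textless_nlp.gslm.unit2speech.tacotron2.text import (
    code_to_sequence, sequence_to_code, sequence_to_text, text_to_sequence)

# Text path: braces mark ARPAbet, everything else is cleaned and mapped to symbol IDs.
ids = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
print(sequence_to_text(ids))   # round-trips, with the ARPAbet chunk re-wrapped in braces

# Unit path: a toy code dictionary; collapsing drops repeated adjacent units,
# which can also trigger the (harmless) OOV warning above.
code_dict = {"17": 1, "42": 2, "63": 3}
seq = code_to_sequence("17 17 42 63".split(), code_dict, collapse_code=True)
print(seq, "->", sequence_to_code(seq, code_dict))   # [1, 2, 3] -> '17 42 63'
```
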
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import collections
7
+ import io
8
+ import json
9
+ import librosa
10
+ import numpy as np
11
+ import soundfile as sf
12
+ import time
13
+ import torch
14
+ from scipy.io.wavfile import read
15
+ from .text import SOS_TOK, EOS_TOK
16
+
17
+
18
+ def get_mask_from_lengths(lengths):
19
+ max_len = torch.max(lengths).item()
20
+ ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
21
+ mask = (ids < lengths.unsqueeze(1))
22
+ return mask
23
+
24
+
25
+ def load_wav_to_torch(full_path, sr=None):
26
+ data, sr = librosa.load(full_path, sr=sr)
27
+ data = np.clip(data, -1, 1) # potentially out of [-1, 1] due to resampling
28
+ data = data * 32768.0 # match values loaded by scipy
29
+ return torch.FloatTensor(data.astype(np.float32)), sr
30
+
31
+
32
+ def read_binary_audio(bin_data, tar_sr=None):
33
+ """
34
+ read binary audio (`bytes` or `uint8` `numpy.ndarray`) to `float32`
35
+ `numpy.ndarray`
36
+
37
+ RETURNS:
38
+ data (np.ndarray) : audio of shape (n,) or (2, n)
39
+ tar_sr (int) : sample rate
40
+ """
41
+ data, ori_sr = sf.read(io.BytesIO(bin_data), dtype='float32')
42
+ data = data.T
43
+ if (tar_sr is not None) and (ori_sr != tar_sr):
44
+ data = librosa.resample(data, ori_sr, tar_sr)
45
+ else:
46
+ tar_sr = ori_sr
47
+ data = np.clip(data, -1, 1)
48
+ data = data * 32768.0
49
+ return torch.FloatTensor(data.astype(np.float32)), tar_sr
50
+
51
+
52
+ def load_filepaths_and_text(filename):
53
+ with open(filename, encoding='utf-8') as f:
54
+ data = [json.loads(line.rstrip()) for line in f]
55
+ return data
56
+
57
+
58
+ def to_gpu(x):
59
+ x = x.contiguous()
60
+
61
+ if torch.cuda.is_available():
62
+ x = x.cuda(non_blocking=True)
63
+ return torch.autograd.Variable(x)
64
+
65
+
66
+ def load_code_dict(path, add_sos=False, add_eos=False):
67
+ if not path:
68
+ return {}
69
+
70
+ with open(path, 'r') as f:
71
+ codes = ['_'] + [line.rstrip() for line in f] # '_' for pad
72
+ code_dict = {c: i for i, c in enumerate(codes)}
73
+
74
+ if add_sos:
75
+ code_dict[SOS_TOK] = len(code_dict)
76
+ if add_eos:
77
+ code_dict[EOS_TOK] = len(code_dict)
78
+ assert(set(code_dict.values()) == set(range(len(code_dict))))
79
+
80
+ return code_dict
81
+
82
+
83
+ def load_obs_label_dict(path):
84
+ if not path:
85
+ return {}
86
+ with open(path, 'r') as f:
87
+ obs_labels = [line.rstrip() for line in f]
88
+ return {c: i for i, c in enumerate(obs_labels)}
89
+
90
+
91
+ # A simple timer class inspired from `tnt.TimeMeter`
92
+ class CudaTimer:
93
+ def __init__(self, keys):
94
+ self.keys = keys
95
+ self.reset()
96
+
97
+ def start(self, key):
98
+ s = torch.cuda.Event(enable_timing=True)
99
+ s.record()
100
+ self.start_events[key].append(s)
101
+ return self
102
+
103
+ def stop(self, key):
104
+ e = torch.cuda.Event(enable_timing=True)
105
+ e.record()
106
+ self.end_events[key].append(e)
107
+ return self
108
+
109
+ def reset(self):
110
+ self.start_events = collections.defaultdict(list)
111
+ self.end_events = collections.defaultdict(list)
112
+ self.running_times = collections.defaultdict(float)
113
+ self.n = collections.defaultdict(int)
114
+ return self
115
+
116
+ def value(self):
117
+ self._synchronize()
118
+ return {k: self.running_times[k] / self.n[k] for k in self.keys}
119
+
120
+ def _synchronize(self):
121
+ torch.cuda.synchronize()
122
+ for k in self.keys:
123
+ starts = self.start_events[k]
124
+ ends = self.end_events[k]
125
+ if len(starts) == 0:
126
+ raise ValueError("Trying to divide by zero in TimeMeter")
127
+ if len(ends) != len(starts):
128
+ raise ValueError("Call stop before checking value!")
129
+ time = 0
130
+ for start, end in zip(starts, ends):
131
+ time += start.elapsed_time(end)
132
+ self.running_times[k] += time * 1e-3
133
+ self.n[k] += len(starts)
134
+ self.start_events = collections.defaultdict(list)
135
+ self.end_events = collections.defaultdict(list)
136
+
137
+
138
+ # Used to measure the time taken for multiple events
139
+ class Timer:
140
+ def __init__(self, keys):
141
+ self.keys = keys
142
+ self.n = {}
143
+ self.running_time = {}
144
+ self.total_time = {}
145
+ self.reset()
146
+
147
+ def start(self, key):
148
+ self.running_time[key] = time.time()
149
+ return self
150
+
151
+ def stop(self, key):
152
+ self.total_time[key] = time.time() - self.running_time[key]
153
+ self.n[key] += 1
154
+ self.running_time[key] = None
155
+ return self
156
+
157
+ def reset(self):
158
+ for k in self.keys:
159
+ self.total_time[k] = 0
160
+ self.running_time[k] = None
161
+ self.n[k] = 0
162
+ return self
163
+
164
+ def value(self):
165
+ vals = {}
166
+ for k in self.keys:
167
+ if self.n[k] == 0:
168
+ raise ValueError("Trying to divide by zero in TimeMeter")
169
+ else:
170
+ vals[k] = self.total_time[k] / self.n[k]
171
+ return vals
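
Both timer helpers above follow the same `start`/`stop`/`value` protocol; a minimal sketch with the CPU-side `Timer` (the `CudaTimer` variant is used the same way but needs a GPU):

```python
import time
from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import Timer

timer = Timer(keys=["forward"])
timer.start("forward")
time.sleep(0.01)              # stand-in for real work
timer.stop("forward")
print(timer.value())          # {'forward': ~0.01} seconds
```
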
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py ADDED
@@ -0,0 +1,40 @@
1
+ # import sys
2
+ # sys.path.append('tacotron2')
3
+ import torch
4
+ from .layers import STFT
5
+
6
+
7
+ class Denoiser(torch.nn.Module):
8
+ """ Removes model bias from audio produced with waveglow """
9
+
10
+ def __init__(self, waveglow, filter_length=1024, n_overlap=4,
11
+ win_length=1024, mode='zeros'):
12
+ super(Denoiser, self).__init__()
13
+ self.stft = STFT(filter_length=filter_length,
14
+ hop_length=int(filter_length/n_overlap),
15
+ win_length=win_length).cuda()
16
+ if mode == 'zeros':
17
+ mel_input = torch.zeros(
18
+ (1, 80, 88),
19
+ dtype=waveglow.upsample.weight.dtype,
20
+ device=waveglow.upsample.weight.device)
21
+ elif mode == 'normal':
22
+ mel_input = torch.randn(
23
+ (1, 80, 88),
24
+ dtype=waveglow.upsample.weight.dtype,
25
+ device=waveglow.upsample.weight.device)
26
+ else:
27
+ raise Exception("Mode {} is not supported".format(mode))
28
+
29
+ with torch.no_grad():
30
+ bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
31
+ bias_spec, _ = self.stft.transform(bias_audio)
32
+
33
+ self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
34
+
35
+ def forward(self, audio, strength=0.1):
36
+ audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
37
+ audio_spec_denoised = audio_spec - self.bias_spec * strength
38
+ audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
39
+ audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
40
+ return audio_denoised
fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import torch
8
+ import numpy as np
9
+ from examples.textless_nlp.gslm.unit2speech.tacotron2.text import (
10
+ EOS_TOK,
11
+ SOS_TOK,
12
+ code_to_sequence,
13
+ text_to_sequence,
14
+ )
15
+ from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import (
16
+ load_code_dict,
17
+ )
18
+
19
+
20
+ class TacotronInputDataset:
21
+ def __init__(self, hparams, append_str=""):
22
+ self.is_text = getattr(hparams, "text_or_code", "text") == "text"
23
+ if not self.is_text:
24
+ self.code_dict = load_code_dict(
25
+ hparams.code_dict, hparams.add_sos, hparams.add_eos
26
+ )
27
+ self.code_key = hparams.code_key
28
+ self.add_sos = hparams.add_sos
29
+ self.add_eos = hparams.add_eos
30
+ self.collapse_code = hparams.collapse_code
31
+ self.append_str = append_str
32
+
33
+ def process_code(self, inp_str):
34
+ inp_toks = inp_str.split()
35
+ if self.add_sos:
36
+ inp_toks = [SOS_TOK] + inp_toks
37
+ if self.add_eos:
38
+ inp_toks = inp_toks + [EOS_TOK]
39
+ return code_to_sequence(inp_toks, self.code_dict, self.collapse_code)
40
+
41
+ def process_text(self, inp_str):
42
+ return text_to_sequence(inp_str, ["english_cleaners"])
43
+
44
+ def get_tensor(self, inp_str):
45
+ # uid, txt, inp_str = self._get_data(idx)
46
+ inp_str = inp_str + self.append_str
47
+ if self.is_text:
48
+ inp_toks = self.process_text(inp_str)
49
+ else:
50
+ inp_toks = self.process_code(inp_str)
51
+ return torch.from_numpy(np.array(inp_toks)).long()
52
+
53
+ def __len__(self):
54
+ return len(self.data)
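
A minimal sketch of feeding unit pseudo-text through `TacotronInputDataset`. In the synthesis scripts the hparams come from the Tacotron checkpoint, so the `SimpleNamespace` below only mocks the attributes this class reads; the code-dictionary path is a placeholder for a file with one unit symbol per line:

```python
from types import SimpleNamespace
from examples.textless_nlp.gslm.unit2speech.tts_data import TacotronInputDataset

hparams = SimpleNamespace(            # illustrative values only
    text_or_code="code",
    code_dict="code_dict.txt",        # hypothetical path: one unit symbol per line
    code_key="code",
    add_sos=False,
    add_eos=False,
    collapse_code=False,
)
dataset = TacotronInputDataset(hparams)
inp = dataset.get_tensor("17 17 42 63")   # 1-D LongTensor of symbol IDs
print(inp.shape)
```
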
fairseq/examples/textless_nlp/gslm/unit2speech/utils.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import torch
8
+ from examples.textless_nlp.gslm.unit2speech.tacotron2.model import Tacotron2
9
+ from examples.textless_nlp.gslm.unit2speech.tacotron2.waveglow_denoiser import (
10
+ Denoiser,
11
+ )
12
+
13
+
14
+ def load_quantized_audio_from_file(file_path):
15
+ base_fname_batch, quantized_units_batch = [], []
16
+ with open(file_path) as f:
17
+ for line in f:
18
+ base_fname, quantized_units_str = line.rstrip().split("|")
19
+ quantized_units = [int(q) for q in quantized_units_str.split(" ")]
20
+ base_fname_batch.append(base_fname)
21
+ quantized_units_batch.append(quantized_units)
22
+ return base_fname_batch, quantized_units_batch
23
+
24
+
25
+ def synthesize_audio(model, waveglow, denoiser, inp, lab=None, strength=0.0):
26
+ assert inp.size(0) == 1
27
+ inp = inp.cuda()
28
+ if lab is not None:
29
+ lab = torch.LongTensor(1).cuda().fill_(lab)
30
+
31
+ with torch.no_grad():
32
+ _, mel, _, ali, has_eos = model.inference(inp, lab, ret_has_eos=True)
33
+ aud = waveglow.infer(mel, sigma=0.666)
34
+ aud_dn = denoiser(aud, strength=strength).squeeze(1)
35
+ return mel, aud, aud_dn, has_eos
36
+
37
+
38
+ def load_tacotron(tacotron_model_path, max_decoder_steps):
39
+ ckpt_dict = torch.load(tacotron_model_path)
40
+ hparams = ckpt_dict["hparams"]
41
+ hparams.max_decoder_steps = max_decoder_steps
42
+ sr = hparams.sampling_rate
43
+ model = Tacotron2(hparams)
44
+ model.load_state_dict(ckpt_dict["model_dict"])
45
+ model = model.cuda().eval().half()
46
+ return model, sr, hparams
47
+
48
+
49
+ def load_waveglow(waveglow_path):
50
+ waveglow = torch.load(waveglow_path)["model"]
51
+ waveglow = waveglow.cuda().eval().half()
52
+ for k in waveglow.convinv:
53
+ k.float()
54
+ denoiser = Denoiser(waveglow)
55
+ return waveglow, denoiser
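
Putting the helpers in this file together, a hedged end-to-end sketch of resynthesizing one line of quantized units into a waveform. This mirrors what `synthesize_audio_from_units.py` in this example directory does; the checkpoint paths and the units file are placeholders, and a GPU is required:

```python
import soundfile as sf
from examples.textless_nlp.gslm.unit2speech.tts_data import TacotronInputDataset
from examples.textless_nlp.gslm.unit2speech.utils import (
    load_quantized_audio_from_file, load_tacotron, load_waveglow, synthesize_audio)

tacotron, sample_rate, hparams = load_tacotron("tts_checkpoint.pt", max_decoder_steps=2000)
waveglow, denoiser = load_waveglow("waveglow_checkpoint.pt")
dataset = TacotronInputDataset(hparams)

# Each line of the units file is expected as "<name>|<u1 u2 ...>",
# the format read by load_quantized_audio_from_file above.
names, unit_batches = load_quantized_audio_from_file("units.txt")
inp = dataset.get_tensor(" ".join(map(str, unit_batches[0]))).unsqueeze(0)  # (1, T)

mel, audio, audio_denoised, has_eos = synthesize_audio(
    tacotron, waveglow, denoiser, inp, strength=0.1)
sf.write("resynthesized.wav", audio_denoised[0].cpu().float().numpy(), sample_rate)
```
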
fairseq/examples/textless_nlp/pgslm/README.md ADDED
@@ -0,0 +1,318 @@
1
+ # Text-Free Prosody-Aware Generative Spoken Language Modeling
2
+
3
+ This folder contains code and recipes to reproduce results reported in a paper _Text-Free Prosody-Aware Generative Spoken Language Modeling_,
4
+ Eugene Kharitonov*, Ann Lee*, Adam Polyak, Yossi Adi, Jade Copet, Kushal Lakhotia, Tu-Anh Nguyen, Morgane Rivière, Abdelrahman Mohamed, Emmanuel Dupoux, Wei-Ning Hsu, 2021. arxiv/2109.03264 [[arxiv]](https://arxiv.org/abs/2109.03264).
5
+
6
+ `*` denotes equal contribution.
7
+
8
+ You can find demo samples [[here]](https://speechbot.github.io/pgslm/index.html).
9
+
10
+ <details>
11
+ <summary>If you find this code useful, please consider citing our work using this bibtex </summary>
12
+
13
+ ```
14
+ @misc{Kharitonov2021,
15
+ title={Text-Free Prosody-Aware Generative Spoken Language Modeling},
16
+ author={Eugene Kharitonov and Ann Lee and Adam Polyak and Yossi Adi and Jade Copet and Kushal Lakhotia and Tu-Anh Nguyen and Morgane Rivière and Abdelrahman Mohamed and Emmanuel Dupoux and Wei-Ning Hsu},
17
+ year={2021},
18
+ eprint={2109.03264},
19
+ archivePrefix={arXiv},
20
+ primaryClass={cs.CL}
21
+ }
22
+ ```
23
+ </details>
24
+
25
+
26
+ ## Additional requirements
27
+ The following packages are required in addition to fairseq; they are installable with pip:
28
+ ```bash
29
+ pip install AMFM-decompy SoundFile scipy sklearn torchaudio npy-append-array
30
+ ```
31
+
32
+ ## Data preprocessing
33
+
34
+ ### Prepare unit pseudo-text transcriptions of the audio
35
+ To get unit transcripts of the speech data we rely on the preprocessing steps of the [GSLM](https://github.com/pytorch/fairseq/tree/main/examples/textless_nlp/gslm/speech2unit/) work.
36
+
37
+ Firstly, we will need to prepare manifest files for the dataset we want to preprocess
38
+ ```
39
+ mkdir manifests/
40
+ python examples/wav2vec/wav2vec_manifest.py --valid-percent=0.0 $DATA_PATH --dest=manifests/train/
41
+ ```
42
+ Next, we need a pre-trained HuBERT-base-ls960 model [[download]](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) and a corresponding kmeans-100 quantizer [[download]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km100/km.bin). Having those we can quantize the dataset:
43
+ ```
44
+ python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \
45
+ --feature_type hubert \
46
+ --kmeans_model_path km.bin \
47
+ --acoustic_model_path hubert_base_ls960.pt \
48
+ --layer 6 \
49
+ --manifest_path manifests/train/train.tsv \
50
+ --out_quantized_file_path manifests/train/units
51
+ ```
52
+
53
+ Finally, by running
54
+ ```
55
+ python examples/textless_nlp/pgslm/scripts/join_units_manifest.py --manifest=manifests/train/train.tsv --units=manifests/train/units --output=train.txt
56
+ ```
57
+ we will get the training data description `train.txt` in the format that pGSLM expects. The above steps have to be repeated for
58
+ dev/test sets. Importantly, we rely on an assumption that the directories are structured as in LibriSpeech, i.e. the file paths follow the
59
+ `<spk_id>/<session_id>/<sample_id>.wav` format.
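
If it is unclear whether a dataset follows this layout, a quick sanity check (an illustrative sketch, not part of the recipe; it assumes the usual `wav2vec_manifest.py` format of a root directory on the first line followed by `<relative_path>\t<num_frames>` rows) is:

```python
# Verify the <spk_id>/<session_id>/<sample_id>.wav assumption over a manifest.
with open("manifests/train/train.tsv") as f:
    root = next(f).strip()
    for line in f:
        rel_path = line.split("\t")[0]
        parts = rel_path.split("/")
        assert len(parts) >= 3, f"unexpected layout: {rel_path}"
        spk_id, session_id = parts[-3], parts[-2]
```
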
60
+
61
+ ### Preprocess data for pGSLM
62
+ The very first step is to obtain the F0 quantization bins.
63
+ Assume the vocoder training manifest is `vocoder_train.txt` (in pGSLM data format prepared with the same process above).
64
+ We prepare the quantized F0 from the vocoder training data by running
65
+ ```sh
66
+ bash examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh \
67
+ vocoder_train.txt <sample_rate> 32 <preprocessed_dir> <output_prefix> # we use 32 bins in the paper
68
+ ```
69
+ - `<sample_rate>`: sampling rate of the audio files in the manifest
70
+ - `<preprocessed_dir>`: where to output the output files
71
+ - `<output_prefix>`: prefix of the output files
72
+
73
+ The script will generate
74
+ - `<output_prefix>.f0_stat.pt`: the speaker-level F0 statistics, which can be used in vocoder training
75
+ - `<output_prefix>_mean_norm_log_f0_bin.th`: the quantized F0, which should be used in `prepare_data.sh` below
76
+
77
+ **Note:** See "Pre-trained models" for the pre-computed speaker-level F0 statistics and quantized F0 bins. We suggest using the pre-computed statistics for the data preparation below in order to take advantage of the pre-trained vocoder for waveform generation.
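
Assuming both artifacts are ordinary `torch.save` files (their `.pt`/`.th` suffixes and the way the rest of the recipe consumes them suggest so), they can be sanity-checked with a short sketch like the one below, with `out/vocoder` standing in for the chosen `<preprocessed_dir>/<output_prefix>`:

```python
import torch

f0_stats = torch.load("out/vocoder.f0_stat.pt")              # speaker-level F0 statistics
f0_bins = torch.load("out/vocoder_mean_norm_log_f0_bin.th")  # quantized F0 bin boundaries

print(type(f0_stats))   # typically a per-speaker mapping of F0 statistics
print(f0_bins)          # with the setting above, boundaries for 32 bins
```
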
78
+
79
+ Next prepare the pGSLM data.
80
+ Assume train/valid/test manifests are `{train,valid,test}.txt`.
81
+ Here is an example of how to preprocess data:
82
+
83
+ ```sh
84
+ bash examples/textless_nlp/pgslm/scripts/prepare_data.sh \
85
+ train.txt valid.txt test.txt <n_unit> <hop_size> <sample_rate> \
86
+ <preprocessed_dir>/<output_prefix>_mean_norm_log_f0_bin.th <preprocessed_dir>
87
+ ```
88
+ - `<n_unit>`: discrete unit vocabulary size (we used a kmeans quantizer with the number of units equal to 100 in the example above)
89
+ - `<hop_size>`: downsampling rate relative to the waveform (e.g., 320 for HuBERT units)
90
+ - `<sample_rate>`: sampling rate of the audio files in the manifest
91
+ - `<preprocessed_dir>`: where to output the preprocessed files
92
+
93
+ This will create the dataset json config used for the next section at
94
+ `<preprocessed_dir>/data_config.json`.
95
+
96
+ Note that the example script uses only one thread to compute F0, which can take
97
+ _very long_ for preprocessing large datasets. It is suggested to distribute
98
+ jobs over multiple nodes/processes with `--nshards=x` and `--rank=z` (where z is
99
+ in [1, x]) in `preprocess_f0.py`, and set `--nshards_list=x` in
100
+ `prepare_data.py` correspondingly to collect sharded F0 data.
101
+
102
+ Now, everything is ready for training a model.
103
+
104
+ ## Training Multi-Stream Transformer Unit Language Model (MS-TLM)
105
+
106
+ Below is an example command that trains Multi-Stream Transformer Language Model (MS-TLM) on a prepared dataset:
107
+ ```bash
108
+ DATASET=data_config.json
109
+
110
+ fairseq-train $DATASET \
111
+ --task=speech_unit_modeling \
112
+ --arch="transformer_ulm_tiny" \
113
+ --criterion=speech_unit_lm_criterion \
114
+ --share-decoder-input-output-embed \
115
+ --dropout=0.1 \
116
+ --attention-dropout=0.1 \
117
+ --optimizer="adam" \
118
+ --adam-betas="(0.9, 0.98)" \
119
+ --clip-norm=1.0 \
120
+ --lr=0.0005 \
121
+ --lr-scheduler="inverse_sqrt" \
122
+ --warmup-updates=4000 \
123
+ --warmup-init-lr=1e-07 \
124
+ --tokens-per-sample=3072 \
125
+ --max-tokens=3072 \
126
+ --update-freq=4 \
127
+ --max-epoch=70 \
128
+ --num-workers=0 \
129
+ --skip-invalid-size-inputs-valid-test \
130
+ --loss-weights="1.0;0.5;0.0" \
131
+ --ignore-f0-input \
132
+ --checkpoint-activations \
133
+ --fp16 \
134
+ --max-target-positions=4096 \
135
+ --stream-shifts="1,1" \
136
+ --log-f0 --normalize-f0-mean --interpolate-f0 \
137
+ --ignore-unused-valid-subsets \
138
+ --discrete-duration --discrete-f0
139
+ ```
140
+
141
+ Some of the important parameters that are specific to MS-TLM:
142
+ * `arch`: specifies the Transformer architecture used. Supported options are:
143
+ * `transformer_ulm_tiny` - a tiny model that can be used for debugging; it has 2 layers, 1 attention head, FFN and embedding dimensions of 64,
144
+ * `transformer_ulm` - a base model with 6 layers, 8 heads, embedding dimension 512, and FFN dimensionality of 2048,
145
+ * `transformer_ulm_big` - the largest model we experiment with in the paper: 12-layer/16 heads, 1024/4096 embedding and FFN dimensions;
146
+ * `loss-weights`: this parameter sets importance weights (must be non-negative) for the components of the loss that correspond to unit, duration, and F0 streams. To turn off a component of the loss, its weight has to be set to 0. For instance, to predict only the unit stream the parameter should be set to "1;0;0";
147
+ * `stream-shifts`: specifies relative shifts of the two prosodic streams w.r.t. the unit stream (duration and F0, respectively). No shift corresponds to "0,0";
148
+ * `ignore-duration-input`/`ignore-f0-input`: setting these flags zeroes out the corresponding input streams;
149
+ * `max-token-duration`: duration values would be max-capped by the specified value;
150
+ * `discrete-duration`/`discrete-f0`: whether duration and F0 streams should be quantized;
151
+ * `log_f0`, `normalize-f0-mean`, `normalize-f0-std`, `interpolate-f0`: configure how F0 stream is treated. `log_f0` sets up modelling in the log-space, `normalize-f0-mean`/`normalize-f0-std` control per-speaker normalization, and `interpolate-f0` enables F0 interpolation for unvoiced regions where F0 was set to 0,
152
+ * `mask-dur-prob`, `mask-f0-prob`, `mask-dur-seg-prob`, `mask-f0-seg-prob`, `mask-unit-seg-prob`, `mask-unit-seg-leng`: this family of parameters sets the probabilities of masking individual steps and spans on each stream as well as the lengths of the masked spans.
153
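+
+ As referenced in the `loss-weights` bullet above, a unit-only configuration keeps the unit loss and drops the prosody streams from the input; relative to the full training command, only the stream-related flags change (a sketch):
+ ```bash
+ # Unit-only variant of the training command above (all other flags unchanged):
+ #   --loss-weights="1.0;0.0;0.0"   # keep only the unit-stream loss
+ #   --ignore-duration-input        # zero out the duration input stream
+ #   --ignore-f0-input              # zero out the F0 input stream
+ ```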
+
154
+
155
+ ## Pre-trained models
156
+ ### MS-TLM
157
+ Below you can find checkpoints for the four best-performing models from the paper (IDs 9..12 in Table 1). These models are trained on Hubert-100 transcripts of the LibriLight-6K dataset. They have the prosody streams shifted by 1 w.r.t. the unit stream. All models predict all three streams (units, duration, and F0), but two
158
+ of them only have the unit stream in their input.
159
+
160
+ | | Continuous prosody | Quantized prosody |
161
+ |-------------------|--------------------|-------------------|
162
+ | No prosody input | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/continuous_no_prosody_shift_1_1.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/discrete_no_prosody_shift_1_1.pt) |
163
+ | Has prosody input | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/continuous_prosody_shift_1_1.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/discrete_prosody_shift_1_1.pt)|
164
+
165
+ The optimal per-stream sampling temperatures/scaling parameters that we have identified for these models are listed below, in the (`T-token, T-duration, T-f0`) format:
166
+
167
+ | | Continuous prosody | Quantized prosody |
168
+ |-------------------|--------------------|-------------------|
169
+ | No prosody input | 0.7, 0.125, 0.0003125| 0.7, 0.25, 0.5 |
170
+ | Has prosody input | 0.7, 0.125, 0.00125 | 0.7, 0.25, 0.7 |
171
+
172
+ ## Vocoder
173
+ | Units | Prosody | F0 stats | Checkpoint | Config |
174
+ |-------------------|---------|--------------|------------|--------|
175
+ | HuBERT-base-ls960, kmeans-100 | [[Quantized 32 bins]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_seg_bin.th) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/f0_stats.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/naive_quant_32_norm_log_seg_hubert/checkpoint.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/naive_quant_32_norm_log_seg_hubert/config.json) |
176
+ | HuBERT-base-ls960, kmeans-100 | Continuous | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/f0_stats.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_hubert/checkpoint.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_hubert/config.json) |
177
+
178
+
179
+ ## Evaluating a trained model
180
+ Evaluation is done with the `eval/cont_metrics.py` script. As described in the paper, several metrics are used.
181
+
182
+ **Teacher-forced metrics**
183
+ ```bash
184
+ SET=valid
185
+ CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt
186
+ DATA=data_config.json
187
+
188
+ python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \
189
+ --metric=teacher_force_everything \
190
+ --path=$CHECKPOINT_PATH \
191
+ --batch-size=16 \
192
+ --fp16 \
193
+ --seed=111 \
194
+ --eval-subset=$SET \
195
+ --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody
196
+ ```
197
+ (Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{'token_loss': 1.408..., 'duration_loss': 0.5424..., 'f0_loss': 0.0474...}` on LibriSpeech dev-clean).
198
+
199
+ The parameters `--f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody` are specific to quantized-prosody models. They signal that the prosody streams must be decoded back into the continuous domain before calculating the metrics. This is the same `*_mean_norm_log_f0_bin.th` file that we prepared before.
200
+ The `mean_norm_log_f0_seg_bin.th` file we used with the pre-trained models can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_seg_bin.th).
201
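+
+ For convenience, the bounds file can be fetched next to the checkpoints before running the evaluation commands (the URL is the one given above):
+ ```bash
+ wget https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_seg_bin.th
+ ```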
+
202
+
203
+ **Consistency (aka Correlation) metrics**
204
+
205
+ The following command estimates the correlation between mean values of the F0 stream in the prompt and in the generated continuation (the unit and duration streams are fixed).
206
+
207
+ ```bash
208
+ T_F0=0.7
209
+ EXPLOSION=20
210
+ SET=test
211
+ CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt
212
+ DATA=data_config.json
213
+
214
+ python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \
215
+ --prefix-length=150 \
216
+ --metric=correlation \
217
+ --path=$CHECKPOINT_PATH \
218
+ --batch-size=16 \
219
+ --fp16 \
220
+ --seed=111 \
221
+ --teacher-force-tokens \
222
+ --teacher-force-duration \
223
+ --min-length=300 \
224
+ --batch-explosion-rate=$EXPLOSION \
225
+ --T-f0=$T_F0 \
226
+ --eval-subset=$SET \
227
+ --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th \
228
+ --dequantize-prosody --n-workers=8
229
+ ```
230
+ (Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{...'F0 corr': 0.315 ..}` on LibriSpeech test-clean).
231
+
232
+ * By using the flags `--teacher-force-tokens, --teacher-force-duration, --teacher-force-f0` one can calculate correlations along each stream while keeping the other two streams fixed to their ground-truth values (or freeze all three streams to get ground-truth correlation values, as sketched after this list);
233
+ * The parameters `T-f0`, `T-duration`, and `T-token` specify per-stream temperatures and, in the case of continuous-valued prosody, the scaling parameter of the corresponding Laplace distribution (setting a temperature to 0 enforces greedy sampling);
234
+ * `min-length` filters out sequences that are shorter than 300 duration units (i.e. 6s in the case of Hubert units);
235
+ * `prefix-length` specifies that we want to use the first 150 duration units as the prompt (i.e. 3s in the case of Hubert units).
236
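+
+ For example, the ground-truth correlation values mentioned above can be obtained by freezing all three streams (a sketch reusing the variables from the command above; since all streams are frozen, a single hypothesis per prompt suffices, hence `--batch-explosion-rate=1`; the `--copy-target` flag of `cont_metrics.py` is an alternative that copies the target streams directly):
+ ```bash
+ python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \
+ --prefix-length=150 \
+ --metric=correlation \
+ --path=$CHECKPOINT_PATH \
+ --batch-size=16 \
+ --fp16 \
+ --seed=111 \
+ --teacher-force-tokens \
+ --teacher-force-duration \
+ --teacher-force-f0 \
+ --min-length=300 \
+ --batch-explosion-rate=1 \
+ --eval-subset=$SET \
+ --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th \
+ --dequantize-prosody --n-workers=8
+ ```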
+
237
+
238
+ **Correctness (aka Continuation) and Expressiveness (aka Std) metrics**
239
+
240
+ By running the following command, we can get minMAE and Std for the log-F0 stream for the model with quantized prosody.
241
+ ```bash
242
+ DATA=data_config.json
243
+ EXPLOSION=20
244
+ SET=test
245
+ CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt
246
+ T_F0=0.7
247
+
248
+ python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \
249
+ --prefix-length=150 \
250
+ --metric=continuation \
251
+ --path=$CHECKPOINT_PATH \
252
+ --batch-size=16 \
253
+ --fp16 \
254
+ --seed=111 \
255
+ --batch-explosion-rate=$EXPLOSION \
256
+ --teacher-force-tokens \
257
+ --teacher-force-duration \
258
+ --T-f0=$T_F0 \
259
+ --eval-subset=$SET \
260
+ --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody
261
+ ```
262
+ (Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{...'F0 MAE': 0.0772, 'F0 Std': 0.1489...}` on LibriSpeech test-clean).
263
+
264
+ Again, by setting `--teacher-force-tokens, --teacher-force-duration, --teacher-force-f0` we can calculate Token BLEU for the token stream (when `--teacher-force-duration` & `--teacher-force-f0` are on, as in the sketch below) and per-stream min MAE for each prosody stream individually.
265
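+
+ For example, a Token BLEU run samples the token stream while freezing both prosody streams (a sketch; the variables are the ones defined in the command above, and the token temperature comes from the temperature table in the "Pre-trained models" section):
+ ```bash
+ python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \
+ --prefix-length=150 \
+ --metric=continuation \
+ --path=$CHECKPOINT_PATH \
+ --batch-size=16 \
+ --fp16 \
+ --seed=111 \
+ --batch-explosion-rate=$EXPLOSION \
+ --teacher-force-duration \
+ --teacher-force-f0 \
+ --T-token=0.7 \
+ --eval-subset=$SET \
+ --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody
+ ```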
+
266
+ Finally, `cont_metrics.py` allows specifying the number of workers (e.g., `--n-workers=8`), which speeds up the computation by spreading multiple worker processes
267
+ over the available GPUs.
268
+
269
+ **Cont Word BLEU**
270
+
271
+ We used the code and the evaluation protocol of [(Lakhotia et al., 2021)](https://arxiv.org/abs/2102.01192).
272
+
273
+ ## Sampling from a trained model
274
+
275
+ To get samples (prompted or unprompted) from a trained model, it is enough to run `sample.py`:
276
+ ```bash
277
+ CHECKPOINT_PATH=checkpoints/checkpoint_best.pt
278
+ DATASET=examples/textless_nlp/pgslm/repro/dataset/data_config.json
+ SAMPLES=samples.out  # output file for the generated continuations (any path)
279
+ python examples/textless_nlp/pgslm/sample/sample.py $DATASET \
280
+ --output=$SAMPLES \
281
+ --path=$CHECKPOINT_PATH \
282
+ --sampling \
283
+ --T-token=0.7 \
284
+ --T-duration=0.25 \
285
+ --T-f0=0.7 \
286
+ --max-length=500 \
287
+ --prefix-length=150 \
288
+ --subset=valid \
289
+ --seed=1 \
290
+ --match-duration \
291
+ --code-type=hubert \
292
+ --batch-explosion-rate=2
293
+ ```
294
+
295
+ Some useful parameters:
296
+ * `T-token`, `T-duration`, `T-f0` specify the sampling temperatures for the three streams. Setting a temperature to `0` switches sampling to greedy (argmax) decoding (see the sketch after this list);
297
+ * `prefix-length`: length of the prompt, measured in timesteps (e.g. for Hubert (CPC) each timestep is 20 (10) ms);
298
+ * `subset`: which subset of the dataset to use as prompts (can be `train`, `valid`, `test`);
299
+ * `teacher-force-tokens`, `teacher-force-duration`, `teacher-force-f0`: if set, at each autoregressive step the ground-truth values replace the produced ones;
300
+ * `short-curcuit`: replaces sampling with the ground-truth inputs;
301
+ * `match-duration`: forces the produced sample to have the same duration (in time) as the entire sequence (beyond the prompt, if there is any);
302
+ * `batch-explosion-rate`: number of samples per prompt;
303
+ * `f0-discretization-bounds`: path to a file with quantization boundaries. If it is set, F0 values are de-quantized back to the continuous domain
304
+ (the model must be a quantized one);
305
+ * `max-length` sets the maximal number of segment steps to be produced.
306
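+
+ For instance, a greedy (argmax) variant of the command above simply zeroes all three temperatures (a sketch; with deterministic decoding a single hypothesis per prompt is enough, hence `--batch-explosion-rate=1`):
+ ```bash
+ python examples/textless_nlp/pgslm/sample/sample.py $DATASET \
+ --output=$SAMPLES \
+ --path=$CHECKPOINT_PATH \
+ --sampling \
+ --T-token=0 \
+ --T-duration=0 \
+ --T-f0=0 \
+ --max-length=500 \
+ --prefix-length=150 \
+ --subset=valid \
+ --seed=1 \
+ --match-duration \
+ --code-type=hubert \
+ --batch-explosion-rate=1
+ ```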
+
307
+ Note that `sample.py` automatically uses all available GPUs; to restrict this, set the `CUDA_VISIBLE_DEVICES` environment variable.
308
+
309
+ ## Vocoding samples
310
+ To generate audio for the output of `sample.py` (`$IN_FILE`):
311
+ ```bash
312
+ python examples/textless_nlp/pgslm/generate_waveform.py \
313
+ --in-file=$IN_FILE \
314
+ --vocoder=$VOCODER \
315
+ --vocoder-cfg=$VOCODER_CFG \
316
+ --results-path=$RESULTS_PATH
317
+ ```
318
+ See "Pre-trained model" for `$VOCODER` and `VOCODER_CFG`.
fairseq/examples/textless_nlp/pgslm/data_utils.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import torch
8
+
9
+ from tqdm import tqdm
10
+
11
+
12
+ class Stat:
13
+ def __init__(self, keep_raw=False):
14
+ self.x = 0.0
15
+ self.x2 = 0.0
16
+ self.z = 0.0 # z = logx
17
+ self.z2 = 0.0
18
+ self.n = 0.0
19
+ self.u = 0.0
20
+ self.keep_raw = keep_raw
21
+ self.raw = []
22
+
23
+ def update(self, new_x):
24
+ new_z = new_x.log()
25
+
26
+ self.x += new_x.sum()
27
+ self.x2 += (new_x**2).sum()
28
+ self.z += new_z.sum()
29
+ self.z2 += (new_z**2).sum()
30
+ self.n += len(new_x)
31
+ self.u += 1
32
+
33
+ if self.keep_raw:
34
+ self.raw.append(new_x)
35
+
36
+ @property
37
+ def mean(self):
38
+ return self.x / self.n
39
+
40
+ @property
41
+ def std(self):
42
+ return (self.x2 / self.n - self.mean**2) ** 0.5
43
+
44
+ @property
45
+ def mean_log(self):
46
+ return self.z / self.n
47
+
48
+ @property
49
+ def std_log(self):
50
+ return (self.z2 / self.n - self.mean_log**2) ** 0.5
51
+
52
+ @property
53
+ def n_frms(self):
54
+ return self.n
55
+
56
+ @property
57
+ def n_utts(self):
58
+ return self.u
59
+
60
+ @property
61
+ def raw_data(self):
62
+ assert self.keep_raw, "does not support storing raw data!"
63
+ return torch.cat(self.raw)
64
+
65
+
66
+ class F0Stat(Stat):
67
+ def update(self, new_x):
68
+ # assume unvoiced frames are 0 and consider only voiced frames
69
+ if new_x is not None:
70
+ super().update(new_x[new_x != 0])
71
+
72
+
73
+ def dump_speaker_f0_stat(speaker_to_f0_stat, out_prefix):
74
+ path = f"{out_prefix}.f0_stat.pt"
75
+ assert not os.path.exists(path)
76
+
77
+ d = {
78
+ speaker: {
79
+ "f0_mean": speaker_to_f0_stat[speaker].mean,
80
+ "f0_std": speaker_to_f0_stat[speaker].std,
81
+ "logf0_mean": speaker_to_f0_stat[speaker].mean_log,
82
+ "logf0_std": speaker_to_f0_stat[speaker].std_log,
83
+ }
84
+ for speaker in speaker_to_f0_stat
85
+ }
86
+ torch.save(d, path)
87
+
88
+ return d
89
+
90
+
91
+ def load_audio_path(path):
92
+ audio_paths = []
93
+ with open(path) as f:
94
+ for line in f.readlines():
95
+ sample = eval(line.strip())
96
+ audio_paths.append(sample["audio"])
97
+
98
+ return audio_paths
99
+
100
+
101
+ def load_f0(f0_dir, nshards):
102
+ path_to_f0 = {}
103
+ for rank in tqdm(range(1, nshards + 1), desc=f"load f0"):
104
+ f0_shard_path = f"{f0_dir}/f0_{rank}_{nshards}.pt"
105
+ shard_path_to_f0 = torch.load(f0_shard_path)
106
+ path_to_f0.update(shard_path_to_f0)
107
+ return path_to_f0
fairseq/examples/textless_nlp/pgslm/eval/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py ADDED
@@ -0,0 +1,730 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import numpy as np
8
+ import scipy.stats
9
+
10
+ import torch
11
+ import torch.multiprocessing as mp
12
+ from fairseq import checkpoint_utils, options
13
+ from fairseq.data.codedataset import CodeDataset, ExpressiveCodeDataConfig
14
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
15
+ from torch.utils.data import DataLoader, DistributedSampler
16
+ from fairseq.utils import move_to_cuda
17
+ from fairseq import utils
18
+ from fairseq.criterions.speech_ulm_criterion import nll_loss, mae_loss
19
+
20
+ import time
21
+ from types import SimpleNamespace
22
+
23
+ import sys, pathlib
24
+
25
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.resolve()))
26
+
27
+ from naive_decoder import Naive_F0_Decoder
28
+ from inference_dataset import InferenceDataset, explode_batch
29
+ from sample.sample import do_sampling, TemperatureDecoder, FilterNamesDataset
30
+
31
+ try:
32
+ from nltk.translate.bleu_score import sentence_bleu
33
+ except ImportError:
34
+ print("Please install nltk: `pip install --user -U nltk`")
35
+ raise
36
+
37
+
38
+ @torch.no_grad()
39
+ def teacher_force_everything(
40
+ args, dataset, model, criterion, tgt_dict, rank, world_size
41
+ ):
42
+ prefix = args.prefix_length
43
+
44
+ f0_decoder = None
45
+ if args.dequantize_prosody:
46
+ assert dataset.discrete_f0
47
+ print("Reporting MAE for a discrete model")
48
+ f0_decoder = Naive_F0_Decoder(
49
+ args.f0_discretization_bounds, dataset.config.f0_vq_n_units
50
+ ).cuda()
51
+
52
+ dataset = InferenceDataset(
53
+ dataset,
54
+ prefix=args.prefix_length,
55
+ only_prefix=False,
56
+ filter_short=True,
57
+ presort_by_length=True,
58
+ )
59
+ sampler = (
60
+ None
61
+ if world_size == 1
62
+ else DistributedSampler(
63
+ dataset, num_replicas=world_size, rank=rank, shuffle=False
64
+ )
65
+ )
66
+ dataloader = DataLoader(
67
+ dataset,
68
+ args.batch_size,
69
+ shuffle=False,
70
+ collate_fn=dataset.collater,
71
+ sampler=sampler,
72
+ )
73
+
74
+ total_token_loss, total_duration_loss, total_f0_loss, total_tokens = (
75
+ 0.0,
76
+ 0.0,
77
+ 0.0,
78
+ 0.0,
79
+ )
80
+
81
+ i = 0
82
+ for batch in dataloader:
83
+ i += 1
84
+ batch = move_to_cuda(batch)
85
+ output = model(**batch["net_input"])
86
+
87
+ tokens, durations, f0 = output["token"], output["duration"], output["f0"]
88
+ durations, f0 = durations.squeeze(), f0.squeeze()
89
+
90
+ token_loss = nll_loss(
91
+ tokens[:, prefix - 1 :],
92
+ batch["target"][:, prefix - 1 :].contiguous(),
93
+ batch["mask"][:, prefix - 1 :].contiguous(),
94
+ reduce=True,
95
+ )
96
+
97
+ if args.dequantize_prosody:
98
+ durations = durations.argmax(dim=-1)
99
+ duration_loss = mae_loss(
100
+ durations[:, prefix - 1 :].contiguous().float(),
101
+ batch["dur_target"][:, prefix - 1 :].contiguous().float(),
102
+ batch["dur_mask"][:, prefix - 1 :].contiguous(),
103
+ reduce=True,
104
+ )
105
+ else:
106
+ duration_loss = criterion.dur_loss_fn(
107
+ durations[:, prefix - 1 :].contiguous(),
108
+ batch["dur_target"][:, prefix - 1 :].contiguous(),
109
+ batch["dur_mask"][:, prefix - 1 :].contiguous(),
110
+ reduce=True,
111
+ )
112
+
113
+ if f0_decoder:
114
+ f0 = f0.argmax(dim=-1)
115
+ f0 = f0_decoder(f0).squeeze(-1)
116
+
117
+ f0_target = batch["raw_f0"]
118
+ f0_loss = mae_loss(
119
+ f0[:, prefix - 1 :].contiguous(),
120
+ f0_target[:, prefix - 1 :].contiguous(),
121
+ batch["f0_mask"][:, prefix - 1 :].contiguous(),
122
+ reduce=True,
123
+ )
124
+ else:
125
+ f0_loss = criterion.f0_loss_fn(
126
+ f0[:, prefix - 1 :].contiguous(),
127
+ batch["f0_target"][:, prefix - 1 :].contiguous(),
128
+ batch["f0_mask"][:, prefix - 1 :].contiguous(),
129
+ reduce=True,
130
+ )
131
+
132
+ n_tokens = (~batch["dur_mask"])[:, prefix - 1 :].sum()
133
+
134
+ total_token_loss += token_loss.item()
135
+ total_duration_loss += duration_loss.item()
136
+ total_f0_loss += f0_loss.item()
137
+
138
+ total_tokens += n_tokens.item()
139
+ if args.debug and i > 5:
140
+ break
141
+
142
+ values = torch.tensor([total_token_loss, total_duration_loss, total_f0_loss])
143
+ normalizers = torch.tensor([total_tokens for _ in range(3)])
144
+
145
+ return values, normalizers
146
+
147
+
148
+ def get_bleu(produced_tokens, target_tokens, tgt_dict):
149
+ assert target_tokens.ndim == 1
150
+ assert produced_tokens.size(1) == target_tokens.size(0)
151
+
152
+ # we can have padding due to shifted channels
153
+ shift = 0
154
+ for token in reversed(target_tokens.cpu().tolist()):
155
+ if token in [tgt_dict.pad(), tgt_dict.eos()]:
156
+ shift += 1
157
+ else:
158
+ break
159
+ target_tokens = target_tokens[:-shift]
160
+ produced_tokens = produced_tokens[:, :-shift]
161
+
162
+ string_target = tgt_dict.string(target_tokens).split()
163
+ string_candidates = [
164
+ tgt_dict.string(produced_tokens[i, :]).split()
165
+ for i in range(produced_tokens.size(0))
166
+ ]
167
+
168
+ bleu3 = sentence_bleu(
169
+ references=string_candidates,
170
+ hypothesis=string_target,
171
+ weights=(1.0 / 3, 1.0 / 3, 1.0 / 3),
172
+ )
173
+ return bleu3
174
+
175
+
176
+ @torch.no_grad()
177
+ def continuation(args, dataset, model, criterion, tgt_dict, rank, world_size):
178
+ is_discrete_duration = dataset.discrete_dur
179
+ is_discrete_f0 = dataset.discrete_f0
180
+
181
+ f0_decoder = None
182
+ if args.dequantize_prosody:
183
+ assert dataset.discrete_f0
184
+ print("Reporting MAE F0 for a discrete model")
185
+ f0_decoder = Naive_F0_Decoder(
186
+ args.f0_discretization_bounds, dataset.config.f0_vq_n_units
187
+ ).cuda()
188
+
189
+ dataset = InferenceDataset(
190
+ dataset, args.prefix_length, filter_short=True, presort_by_length=True
191
+ )
192
+ sampler = (
193
+ None
194
+ if world_size == 1
195
+ else DistributedSampler(
196
+ dataset, num_replicas=world_size, rank=rank, shuffle=False
197
+ )
198
+ )
199
+ dataloader = DataLoader(
200
+ dataset,
201
+ batch_size=1,
202
+ shuffle=False,
203
+ collate_fn=dataset.collater,
204
+ sampler=sampler,
205
+ )
206
+
207
+ Ts = args.T_token, args.T_duration, args.T_f0
208
+ decoder = TemperatureDecoder(
209
+ Ts, discrete_dur=is_discrete_duration, discrete_f0=is_discrete_f0
210
+ )
211
+
212
+ running_stats = SimpleNamespace(
213
+ token_bleu=0.0,
214
+ duration_nll=0.0,
215
+ duration_mae=0.0,
216
+ f0_nll=0.0,
217
+ f0_mae=0.0,
218
+ n_tokens=0.0,
219
+ n_sentences=0.0,
220
+ f0_sum=0.0,
221
+ f0_sum_sq=0.0,
222
+ dur_sum=0.0,
223
+ dur_sum_sq=0.0,
224
+ )
225
+
226
+ for i, batch in enumerate(dataloader):
227
+ batch = explode_batch(batch, args.batch_explosion_rate)
228
+ bsz = batch["target"].size(0)
229
+
230
+ batch = move_to_cuda(batch)
231
+ prefix = batch["prefix"][0]
232
+
233
+ max_length_to_unroll = batch["target"].size(1)
234
+ prefix_length = batch["net_input"]["src_tokens"].size(1)
235
+ steps = max_length_to_unroll - prefix_length + 1
236
+
237
+ assert steps > 0
238
+ produced_tokens, produced_durations, produced_f0, outputs = do_sampling(
239
+ model,
240
+ batch,
241
+ tgt_dict.eos(),
242
+ decoder,
243
+ autoregressive_steps=steps,
244
+ teacher_force_tokens=args.teacher_force_tokens,
245
+ teacher_force_duration=args.teacher_force_duration,
246
+ teacher_force_f0=args.teacher_force_f0,
247
+ )
248
+
249
+ if args.teacher_force_tokens:
250
+ assert (produced_tokens[:, 1:] == batch["target"]).all()
251
+ if args.teacher_force_duration:
252
+ assert (produced_durations[:, 1:] == batch["dur_target"]).all()
253
+ if args.teacher_force_f0:
254
+ assert (produced_f0[:, 1:] == batch["f0_target"]).all()
255
+
256
+ dur_target = batch["dur_target"][:, prefix - 1 :].contiguous()
257
+ f0_target = batch["f0_target"][:, prefix - 1 :].contiguous()
258
+
259
+ f0_mask = batch["f0_mask"][:, prefix - 1 :].contiguous()
260
+ dur_mask = batch["dur_mask"][:, prefix - 1 :].contiguous()
261
+
262
+ duration_mae = mae_loss(
263
+ produced_durations[:, prefix:].float(),
264
+ dur_target.float(),
265
+ dur_mask,
266
+ reduce=False,
267
+ )
268
+ min_duration_mae = duration_mae.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
269
+ running_stats.duration_mae += min_duration_mae
270
+
271
+ running_stats.dur_sum += (
272
+ produced_durations[:, prefix:].float() * (~dur_mask)
273
+ ).sum() / args.batch_explosion_rate
274
+ running_stats.dur_sum_sq += (
275
+ produced_durations[:, prefix:].float() * (~dur_mask)
276
+ ).pow(2.0).sum() / args.batch_explosion_rate
277
+
278
+ if is_discrete_duration:
279
+ duration_loss = criterion.dur_loss_fn(
280
+ torch.stack([x[1] for x in outputs], dim=1),
281
+ dur_target,
282
+ dur_mask,
283
+ reduce=False,
284
+ )
285
+ min_duration_loss = duration_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
286
+ running_stats.duration_nll += min_duration_loss
287
+
288
+ if f0_decoder: # can only exist for discrete F0 models
289
+ decoded_produced_f0 = f0_decoder(produced_f0[:, prefix:])
290
+ decoded_f0_target = batch["raw_f0"][:, prefix - 1 :].contiguous()
291
+
292
+ if produced_f0.ndim == 3:
293
+ decoded_produced_f0 = decoded_produced_f0.squeeze(2)
294
+ decoded_f0_target = decoded_f0_target.squeeze(2)
295
+
296
+ f0_mae = mae_loss(
297
+ decoded_produced_f0, decoded_f0_target, f0_mask, reduce=False
298
+ )
299
+ f0_mae = f0_mae.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
300
+ running_stats.f0_mae += f0_mae
301
+
302
+ f0_loss = criterion.f0_loss_fn(
303
+ torch.stack([x[2] for x in outputs], dim=1),
304
+ f0_target.long(),
305
+ f0_mask,
306
+ reduce=False,
307
+ )
308
+ f0_loss = f0_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
309
+ running_stats.f0_nll += f0_loss
310
+
311
+ running_stats.f0_sum += (
312
+ decoded_produced_f0 * (~f0_mask)
313
+ ).sum() / args.batch_explosion_rate
314
+ running_stats.f0_sum_sq += (decoded_produced_f0 * (~f0_mask)).pow(
315
+ 2.0
316
+ ).sum() / args.batch_explosion_rate
317
+
318
+ else:
319
+ assert not is_discrete_duration
320
+
321
+ f0_loss = mae_loss(
322
+ produced_f0[:, prefix:], f0_target, f0_mask, reduce=False
323
+ )
324
+ f0_loss = f0_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
325
+ running_stats.f0_mae += f0_loss
326
+
327
+ running_stats.f0_sum += (
328
+ produced_f0[:, prefix:].sum() / args.batch_explosion_rate
329
+ )
330
+ running_stats.f0_sum_sq += (
331
+ produced_f0[:, prefix:].pow(2.0).sum() / args.batch_explosion_rate
332
+ )
333
+
334
+ running_stats.n_tokens += (~dur_mask)[0, ...].sum()
335
+
336
+ token_loss = get_bleu(
337
+ produced_tokens[:, prefix:], batch["target"][0, prefix - 1 :], tgt_dict
338
+ )
339
+ running_stats.token_bleu += token_loss
340
+ running_stats.n_sentences += 1
341
+
342
+ if args.debug:
343
+ break
344
+
345
+ values = torch.tensor(
346
+ [
347
+ running_stats.token_bleu,
348
+ running_stats.duration_nll,
349
+ running_stats.duration_mae,
350
+ running_stats.f0_nll,
351
+ running_stats.f0_mae,
352
+ running_stats.f0_sum,
353
+ running_stats.f0_sum_sq,
354
+ running_stats.dur_sum,
355
+ running_stats.dur_sum_sq,
356
+ ]
357
+ )
358
+ normalizers = torch.tensor(
359
+ [running_stats.n_sentences] + [running_stats.n_tokens] * 8
360
+ )
361
+
362
+ return values, normalizers
363
+
364
+
365
+ @torch.no_grad()
366
+ def correlation(args, dataset, model, criterion, tgt_dict, rank, world_size):
367
+ is_discrete_duration = dataset.discrete_dur
368
+ is_discrete_f0 = dataset.discrete_f0
369
+
370
+ f0_decoder = None
371
+ if is_discrete_f0:
372
+ assert dataset.discrete_f0
373
+ f0_decoder = Naive_F0_Decoder(
374
+ args.f0_discretization_bounds, dataset.config.f0_vq_n_units
375
+ ).cuda()
376
+
377
+ if is_discrete_f0:
378
+ assert f0_decoder # correlation on tokens is meaningless
379
+
380
+ dataset = InferenceDataset(
381
+ dataset,
382
+ args.prefix_length,
383
+ filter_short=True,
384
+ presort_by_length=True,
385
+ min_length=args.min_length,
386
+ )
387
+ sampler = (
388
+ None
389
+ if world_size == 1
390
+ else DistributedSampler(
391
+ dataset, num_replicas=world_size, rank=rank, shuffle=False
392
+ )
393
+ )
394
+ dataloader = DataLoader(
395
+ dataset,
396
+ batch_size=1,
397
+ shuffle=False,
398
+ collate_fn=dataset.collater,
399
+ sampler=sampler,
400
+ )
401
+
402
+ Ts = args.T_token, args.T_duration, args.T_f0
403
+ decoder = TemperatureDecoder(
404
+ Ts, discrete_dur=is_discrete_duration, discrete_f0=is_discrete_f0
405
+ )
406
+
407
+ mean_dur_prefix, mean_dur_cont = [], []
408
+ mean_f0_prefix, mean_f0_cont = [], []
409
+
410
+ for batch in dataloader:
411
+ batch = explode_batch(batch, args.batch_explosion_rate)
412
+ batch = move_to_cuda(batch)
413
+
414
+ assert len(batch["prefix"]) == 1
415
+
416
+ if args.teacher_force_tokens:
417
+ autoregressive_steps = batch["target"].size(1) - args.prefix_length - 1
418
+ else:
419
+ autoregressive_steps = args.max_length - args.prefix_length # + max_shift?
420
+
421
+ if args.copy_target:
422
+ produced_durations, produced_f0 = batch["dur_target"], batch["f0_target"]
423
+ else:
424
+ _, produced_durations, produced_f0, outputs = do_sampling(
425
+ model,
426
+ batch,
427
+ tgt_dict.eos(),
428
+ decoder,
429
+ autoregressive_steps=autoregressive_steps,
430
+ teacher_force_tokens=args.teacher_force_tokens,
431
+ teacher_force_duration=args.teacher_force_duration,
432
+ teacher_force_f0=args.teacher_force_f0,
433
+ )
434
+
435
+ # first tokens actually correspond to BOS
436
+ produced_durations = produced_durations[:, 1:]
437
+ produced_f0 = produced_f0[:, 1:]
438
+
439
+ dur_target = batch["dur_target"]
440
+ if is_discrete_duration:
441
+ produced_durations = produced_durations.float()
442
+ dur_target = dur_target.float()
443
+
444
+ if is_discrete_f0:
445
+ produced_f0 = f0_decoder(produced_f0).squeeze(-1)
446
+ f0_target = batch["raw_f0"]
447
+ else:
448
+ f0_target = batch["f0_target"]
449
+
450
+ # prefix values
451
+ prefix = batch["prefix"][0]
452
+ dur_prefix_mean = dur_target[:, :prefix].sum(dim=-1) / (
453
+ (~batch["dur_mask"][:, :prefix]).sum(dim=-1)
454
+ )
455
+
456
+ non_voiced = f0_target[:, :prefix] == 0.0
457
+ f0_mask = batch["f0_mask"][:, :prefix].logical_or(non_voiced)
458
+ f0_prefix_mean = f0_target[:, :prefix].sum(dim=-1) / ((~f0_mask).sum(dim=-1))
459
+
460
+ # continuation values
461
+ dur_cont_mean = produced_durations[:, prefix:].sum(dim=-1) / (
462
+ (~batch["dur_mask"][:, prefix:]).sum(dim=-1)
463
+ )
464
+
465
+ non_voiced = produced_f0[:, prefix:] == 0.0
466
+ f0_mask = non_voiced
467
+ f0_cont_mean = produced_f0[:, prefix:].sum(dim=-1) / ((~f0_mask).sum(dim=-1))
468
+
469
+ assert not f0_cont_mean.isnan().any()
470
+
471
+ mean_dur_prefix.append(dur_prefix_mean.cpu())
472
+ mean_dur_cont.append(dur_cont_mean.cpu())
473
+
474
+ mean_f0_prefix.append(f0_prefix_mean.cpu())
475
+ mean_f0_cont.append(f0_cont_mean.cpu())
476
+
477
+ if args.debug and len(mean_dur_prefix) > 10:
478
+ break
479
+
480
+ mean_dur_prefix, mean_dur_cont = torch.cat(mean_dur_prefix), torch.cat(
481
+ mean_dur_cont
482
+ )
483
+ mean_f0_prefix, mean_f0_cont = torch.cat(mean_f0_prefix), torch.cat(mean_f0_cont)
484
+
485
+ return mean_dur_prefix, mean_dur_cont, mean_f0_prefix, mean_f0_cont
486
+
487
+
488
+ def main(rank, world_size, args):
489
+ start = time.time()
490
+
491
+ if world_size > 1:
492
+ torch.distributed.init_process_group(
493
+ backend="gloo", init_method="env://", world_size=world_size, rank=rank
494
+ )
495
+ torch.cuda.set_device(rank % torch.cuda.device_count())
496
+
497
+ raw_args = args
498
+
499
+ args = convert_namespace_to_omegaconf(args)
500
+ if args.common.seed is not None:
501
+ np.random.seed(args.common.seed)
502
+ utils.set_torch_seed(args.common.seed)
503
+
504
+ models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
505
+ [raw_args.path], arg_overrides={"data": args.task.data}
506
+ )
507
+
508
+ tgt_dict = task.target_dictionary
509
+
510
+ for model in models:
511
+ model.prepare_for_inference_(args)
512
+ model.cuda().eval()
513
+ if raw_args.fp16:
514
+ model = model.half()
515
+ model = models[0]
516
+
517
+ config = ExpressiveCodeDataConfig(args.task.data)
518
+
519
+ dataset = CodeDataset(
520
+ manifest=config.manifests[raw_args.eval_subset],
521
+ dictionary=task.source_dictionary,
522
+ dur_dictionary=task.source_duration_dictionary,
523
+ f0_dictionary=task.source_f0_dictionary,
524
+ config=config,
525
+ discrete_dur=task.cfg.discrete_duration,
526
+ discrete_f0=task.cfg.discrete_f0,
527
+ log_f0=task.cfg.log_f0,
528
+ normalize_f0_mean=task.cfg.normalize_f0_mean,
529
+ normalize_f0_std=task.cfg.normalize_f0_std,
530
+ interpolate_f0=task.cfg.interpolate_f0,
531
+ shifts=task.cfg.stream_shifts,
532
+ return_filename=True,
533
+ strip_filename=False,
534
+ return_continuous_f0=raw_args.dequantize_prosody,
535
+ )
536
+
537
+ if raw_args.filter_names:
538
+ dataset = FilterNamesDataset(dataset, raw_args.filter_names)
539
+
540
+ criterion = task.build_criterion(model_args.criterion)
541
+
542
+ name2metric = {
543
+ "continuation": continuation,
544
+ "teacher_force_everything": teacher_force_everything,
545
+ "correlation": correlation,
546
+ }
547
+
548
+ name2keys = {
549
+ "continuation": (
550
+ "Token BLEU3",
551
+ "Duration NLL",
552
+ "Duration MAE",
553
+ "F0 NLL",
554
+ "F0 MAE",
555
+ "F0 sum",
556
+ "F0 sum_sq",
557
+ "Dur sum",
558
+ "Dur sum_sq",
559
+ ),
560
+ "teacher_force_everything": ("token_loss", "duration_loss", "f0_loss"),
561
+ "correlation": ("Duration corr", "F0 corr"),
562
+ }
563
+ metric_name = raw_args.metric
564
+
565
+ metric = name2metric[metric_name]
566
+ results = metric(raw_args, dataset, model, criterion, tgt_dict, rank, world_size)
567
+
568
+ values = None
569
+
570
+ if metric_name not in [
571
+ "correlation",
572
+ ]:
573
+ values, normalizers = results
574
+ values = maybe_aggregate_normalize(values, normalizers, world_size)
575
+ elif metric_name == "correlation":
576
+ values = maybe_aggregate_correlations(results, world_size)
577
+ else:
578
+ assert False
579
+
580
+ assert values is not None
581
+ summary = dict(zip(name2keys[raw_args.metric], values.tolist()))
582
+ if metric_name == "continuation":
583
+ summary["F0 Std"] = np.sqrt(-summary["F0 sum"] ** 2 + summary["F0 sum_sq"])
584
+ summary["Dur Std"] = np.sqrt(-summary["Dur sum"] ** 2 + summary["Dur sum_sq"])
585
+ del summary["F0 sum"]
586
+ del summary["F0 sum_sq"]
587
+ del summary["Dur sum"]
588
+ del summary["Dur sum_sq"]
589
+
590
+ summary["metric"] = metric_name
591
+
592
+ if rank == 0:
593
+ print(summary)
594
+ if raw_args.wandb:
595
+ wandb_results(summary, raw_args)
596
+ print("# finished in ", time.time() - start, "seconds")
597
+
598
+
599
+ def wandb_results(summary, raw_args):
600
+ import wandb
601
+
602
+ run = wandb.init(
603
+ project=raw_args.wandb_project_name, tags=raw_args.wandb_tags.split(",")
604
+ )
605
+ run.config.metric = raw_args.metric
606
+ run.config.model = raw_args.path
607
+ run.config.data = raw_args.data
608
+
609
+ if raw_args.wandb_run_name:
610
+ run.name = raw_args.wandb_run_name
611
+ run.save()
612
+
613
+ wandb.log(summary)
614
+ wandb.finish()
615
+
616
+
617
+ def maybe_aggregate_normalize(values, normalizers, world_size):
618
+ if world_size > 1:
619
+ torch.distributed.barrier()
620
+
621
+ torch.distributed.all_reduce_multigpu([values])
622
+ torch.distributed.all_reduce_multigpu([normalizers])
623
+
624
+ return values / normalizers
625
+
626
+
627
+ def maybe_aggregate_correlations(results, world_size):
628
+ if world_size > 1:
629
+ output = [None for _ in range(world_size)]
630
+ torch.distributed.all_gather_object(output, results)
631
+ mean_dur_prefix, mean_dur_cont, mean_f0_prefix, mean_f0_cont = [
632
+ torch.cat([x[i] for x in output]) for i in range(4)
633
+ ]
634
+ else:
635
+ mean_dur_prefix, mean_dur_cont, mean_f0_prefix, mean_f0_cont = results
636
+
637
+ corr_dur = scipy.stats.pearsonr(mean_dur_prefix.numpy(), mean_dur_cont.numpy())[0]
638
+ corr_f0 = scipy.stats.pearsonr(mean_f0_prefix.numpy(), mean_f0_cont.numpy())[0]
639
+ values = torch.tensor([corr_dur, corr_f0])
640
+
641
+ return values
642
+
643
+
644
+ def cli_main():
645
+ parser = options.get_interactive_generation_parser()
646
+ parser.add_argument(
647
+ "--prefix-length",
648
+ type=int,
649
+ default=1,
650
+ help="Prompt prefix length (including <s>)",
651
+ )
652
+ parser.add_argument(
653
+ "--duration-scale",
654
+ type=float,
655
+ default=1,
656
+ help="Multiply durations by the given scaler",
657
+ )
658
+ parser.add_argument(
659
+ "--debug", action="store_true", help="Process only the first batch"
660
+ )
661
+ parser.add_argument("--n_hypotheses", type=int, default=1)
662
+ parser.add_argument("--filter-names", type=str, default=None)
663
+ parser.add_argument(
664
+ "--max-length", type=int, default=200, help="Maximal produced length"
665
+ )
666
+
667
+ parser.add_argument("--teacher-force-tokens", action="store_true", default=False)
668
+ parser.add_argument("--teacher-force-duration", action="store_true", default=False)
669
+ parser.add_argument("--teacher-force-f0", action="store_true", default=False)
670
+
671
+ parser.add_argument("--copy-target", action="store_true", default=False)
672
+ parser.add_argument("--min-length", type=int, default=None)
673
+ parser.add_argument("--f0-discretization-bounds", type=str, default=None)
674
+ parser.add_argument("--dequantize-prosody", action="store_true")
675
+ parser.add_argument("--batch-explosion-rate", type=int, default=1)
676
+
677
+ parser.add_argument(
678
+ "--metric",
679
+ choices=["continuation", "teacher_force_everything", "correlation"],
680
+ required=True,
681
+ )
682
+
683
+ parser.add_argument("--wandb", action="store_true")
684
+ parser.add_argument("--wandb-project-name", type=str, default="eslm")
685
+ parser.add_argument("--wandb-tags", type=str, default="")
686
+ parser.add_argument("--wandb-run-name", type=str, default="")
687
+
688
+ parser.add_argument("--T-token", type=float, default=1.0)
689
+ parser.add_argument("--T-duration", type=float, default=1.0)
690
+ parser.add_argument("--T-f0", type=float, default=1.0)
691
+
692
+ parser.add_argument("--n-workers", type=int, default=1)
693
+
694
+ parser.add_argument(
695
+ "--eval-subset", type=str, default="valid", choices=["valid", "test"]
696
+ )
697
+
698
+ args = options.parse_args_and_arch(parser)
699
+
700
+ assert (
701
+ args.prefix_length >= 1
702
+ ), "Prefix length includes bos token <s>, hence the minimum is 1."
703
+ assert args.temperature >= 0.0, "T must be non-negative!"
704
+
705
+ if args.dequantize_prosody:
706
+ assert args.f0_discretization_bounds
707
+
708
+ world_size = args.n_workers or torch.cuda.device_count()
709
+ if world_size > 1:
710
+ import random
711
+
712
+ mp.set_start_method("spawn", force=True)
713
+ os.environ["MASTER_ADDR"] = "localhost"
714
+ os.environ["MASTER_PORT"] = str(random.randint(10_000, 50_000))
715
+
716
+ mp.spawn(
717
+ main,
718
+ nprocs=world_size,
719
+ args=(
720
+ world_size,
721
+ args,
722
+ ),
723
+ join=True,
724
+ )
725
+ else:
726
+ main(rank=0, world_size=world_size, args=args)
727
+
728
+
729
+ if __name__ == "__main__":
730
+ cli_main()
fairseq/examples/textless_nlp/pgslm/generate_waveform.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import ast
7
+ import argparse
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ import soundfile as sf
12
+ import torch
13
+
14
+ from tqdm import tqdm
15
+
16
+ from fairseq import utils
17
+ from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
18
+
19
+
20
+ logging.basicConfig()
21
+ logging.root.setLevel(logging.INFO)
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ def dump_result(args, data, sample_id, pred_wav):
27
+ assert "audio" in data or args.results_path is not None
28
+ if args.results_path:
29
+ fname = Path(data["audio"]).name if "audio" in data else f"{sample_id}_pred.wav"
30
+ out_file = Path(args.results_path) / fname
31
+
32
+ sf.write(
33
+ out_file.as_posix(),
34
+ pred_wav.detach().cpu().numpy(),
35
+ args.sample_rate,
36
+ )
37
+
38
+
39
+ def load_data(in_file):
40
+ with open(in_file) as f:
41
+ data = [ast.literal_eval(line.strip()) for line in f]
42
+
43
+ return data
44
+
45
+
46
+ def get_f0_upsample_ratio(code_hop_size, f_hop_size):
47
+ ratio = (code_hop_size // 160) // (f_hop_size // 256) * 2
48
+ return ratio
49
+
50
+
51
+ def main(args):
52
+ logger.info(args)
53
+
54
+ use_cuda = torch.cuda.is_available() and not args.cpu
55
+
56
+ with open(args.vocoder_cfg) as f:
57
+ vocoder_cfg = json.load(f)
58
+ vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg)
59
+ if use_cuda:
60
+ vocoder = vocoder.cuda()
61
+
62
+ data = load_data(args.in_file)
63
+
64
+ if args.results_path:
65
+ Path(args.results_path).mkdir(exist_ok=True, parents=True)
66
+
67
+ for i, d in tqdm(enumerate(data), total=len(data)):
68
+ code_key = "cpc_km100" if "cpc_km100" in d else "hubert"
69
+ code = list(map(int, d[code_key].split()))
70
+
71
+ x = {
72
+ "code": torch.LongTensor(code).view(1, -1),
73
+ "f0": torch.Tensor(d["f0"]).view(1, -1),
74
+ }
75
+
76
+ f0_up_ratio = get_f0_upsample_ratio(
77
+ vocoder_cfg["code_hop_size"], vocoder_cfg["hop_size"]
78
+ )
79
+ if f0_up_ratio > 1:
80
+ bsz, cond_length = x["f0"].size()
81
+ x["f0"] = x["f0"].unsqueeze(2).repeat(1, 1, f0_up_ratio).view(bsz, -1)
82
+
83
+ x = utils.move_to_cuda(x) if use_cuda else x
84
+ wav = vocoder(x)
85
+ dump_result(args, d, i, wav)
86
+
87
+
88
+ def cli_main():
89
+ parser = argparse.ArgumentParser()
90
+ parser.add_argument(
91
+ "--in-file",
92
+ type=str,
93
+ required=True,
94
+ help="Input file following the same format of the output from sample.py ('f0' and 'cpc_km100/hubert' are required fields)",
95
+ )
96
+ parser.add_argument(
97
+ "--vocoder", type=str, required=True, help="path to the vocoder"
98
+ )
99
+ parser.add_argument(
100
+ "--vocoder-cfg",
101
+ type=str,
102
+ required=True,
103
+ help="path to the vocoder config",
104
+ )
105
+ parser.add_argument("--sample-rate", type=int, default=16_000)
106
+ parser.add_argument(
107
+ "--results-path",
108
+ type=str,
109
+ default=None,
110
+ help="Output directory. If not set, the audios will be stored following the 'audio' field specified in the input file.",
111
+ )
112
+ parser.add_argument("--cpu", action="store_true", help="run on CPU")
113
+
114
+ args = parser.parse_args()
115
+
116
+ main(args)
117
+
118
+
119
+ if __name__ == "__main__":
120
+ cli_main()
fairseq/examples/textless_nlp/pgslm/inference_dataset.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import torch
8
+
9
+
10
+ class InferenceDataset:
11
+ def __init__(
12
+ self,
13
+ dataset,
14
+ prefix,
15
+ only_prefix=True,
16
+ presort_by_length=True,
17
+ filter_short=False,
18
+ min_length=None,
19
+ ):
20
+ self.dataset = dataset
21
+ self.collater = self.dataset.collater
22
+ self.prefix = prefix
23
+ self.only_prefix = only_prefix
24
+ self.filter_short = filter_short
25
+
26
+ self.remapping = list(range(len(self.dataset)))
27
+ if min_length:
28
+ assert min_length >= prefix + 1
29
+
30
+ length_thr = prefix + 1 if not min_length else min_length
31
+
32
+ if filter_short:
33
+ self.remapping = list(
34
+ filter(
35
+ lambda i: self.dataset[i]["dur_source"].sum() > length_thr,
36
+ self.remapping,
37
+ )
38
+ )
39
+ print(
40
+ f"# the initial dataset of {len(self.dataset)} examples became {len(self.remapping)} after filtering"
41
+ f" examples shorter than {length_thr} (in duration units)"
42
+ )
43
+
44
+ if presort_by_length:
45
+ lengths = {index: dataset.size(index) for index in self.remapping}
46
+ self.remapping.sort(key=lambda i: lengths[i])
47
+
48
+ @property
49
+ def pads(self):
50
+ return self.dataset.pads
51
+
52
+ def __len__(self):
53
+ return len(self.remapping)
54
+
55
+ def original_size(self, k):
56
+ k = self.remapping[k]
57
+ return self.dataset.size(k)
58
+
59
+ def __getitem__(self, k):
60
+ k = self.remapping[k]
61
+ channels = self.dataset[k]
62
+
63
+ if self.prefix and self.only_prefix:
64
+ dur_channel = channels["dur_source"]
65
+ assert dur_channel.sum() >= self.prefix
66
+
67
+ token_times = dur_channel.cumsum(dim=-1)
68
+ cut_after = torch.searchsorted(token_times, torch.tensor(self.prefix))
69
+
70
+ r = {}
71
+ for channel_name, value in channels.items():
72
+ if isinstance(value, torch.Tensor) and "source" in channel_name:
73
+ # if self.filter_short: assert value.size(0) >= self.prefix
74
+ r[channel_name] = value[: cut_after + 1]
75
+ else:
76
+ r[channel_name] = value
77
+
78
+ r["prefix"] = cut_after + 1
79
+ else:
80
+ r = channels
81
+
82
+ return r
83
+
84
+
85
+ def explode_batch(batch, times):
86
+ if times == 1:
87
+ return batch
88
+
89
+ new_batch = {}
90
+
91
+ for key, value in batch.items():
92
+ if isinstance(value, torch.Tensor):
93
+ assert value.size(0) == 1
94
+ new_batch[key] = torch.cat([value] * times)
95
+ elif key in ["ntokens", "nsentences"]:
96
+ new_batch[key] = value * times
97
+ elif key in ["prefix", "filename"]:
98
+ new_batch[key] = value
99
+ elif key == "net_input":
100
+ new_batch[key] = explode_batch(value, times)
101
+ else:
102
+ assert False, key
103
+ return new_batch
fairseq/examples/textless_nlp/pgslm/naive_decoder.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import warnings
8
+
9
+
10
+ class Naive_F0_Decoder(torch.nn.Module):
11
+ def __init__(self, bounds_path, n_units=32):
12
+ super().__init__()
13
+
14
+ bounds = torch.load(bounds_path)
15
+ bounds = torch.from_numpy(bounds[n_units])
16
+ assert bounds.ndim == 1
17
+
18
+ pad = torch.tensor([-5.0, -5.0]) # bos, eos, pad are in the dictionary
19
+ centers = torch.cat(
20
+ [bounds[0:1], 0.5 * (bounds[1:] + bounds[:-1]), bounds[-1:], pad[:]]
21
+ )
22
+
23
+ self.embedding = torch.nn.Embedding.from_pretrained(
24
+ centers.unsqueeze(-1), freeze=True
25
+ )
26
+ self.max_n = self.embedding.weight.numel()
27
+
28
+ def forward(self, discrete_f0: torch.Tensor):
29
+ in_bounds = (0 <= discrete_f0).all() and (discrete_f0 < self.max_n).all()
30
+ if not in_bounds:
31
+ warnings.warn(
32
+ f"F0 contains some weird outputs: discrete_f0.max().item()={discrete_f0.max().item()} discrete_f0.min().item()={discrete_f0.min().item()}; "
33
+ f"while we have embeddings for {self.max_n} values. "
34
+ "Assuming this is a no-prosody model -- but be careful!"
35
+ )
36
+
37
+ mask = discrete_f0 >= self.max_n
38
+ discrete_f0 = discrete_f0.masked_fill(mask, self.max_n - 1)
39
+
40
+ return self.embedding(discrete_f0).squeeze(-1)
fairseq/examples/textless_nlp/pgslm/prepare_dataset.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from multiprocessing import Pool
7
+
8
+ import os
9
+ from collections import defaultdict
10
+ from itertools import starmap
11
+
12
+ import torch
13
+ from npy_append_array import NpyAppendArray
14
+ from tqdm import tqdm
15
+
16
+ from data_utils import dump_speaker_f0_stat, F0Stat, load_f0
17
+ from fairseq.data.codedataset import (
18
+ ExpressiveCodeDataConfig,
19
+ parse_manifest,
20
+ F0_FRAME_SPACE,
21
+ align_f0_to_durations,
22
+ )
23
+ from fairseq.tasks.speech_ulm_task import UnitDictionary
24
+
25
+
26
+ def load_meta(meta_path, split):
27
+ config = ExpressiveCodeDataConfig(meta_path)
28
+ manifest_path = config.manifests[split]
29
+ dictionary = UnitDictionary(n_units=config.n_units)
30
+ audio_paths, codes, durs, speakers = parse_manifest(manifest_path, dictionary)
31
+ return config, audio_paths, codes, durs, speakers
32
+
33
+
34
+ def _align_f0(f0, dur, ratio, frm_tol=5):
35
+ if f0 is None:
36
+ seg_f0 = torch.zeros_like(dur, dtype=torch.float)
37
+ else:
38
+ seg_f0 = align_f0_to_durations(f0, dur, ratio, tol=frm_tol * ratio)
39
+ return seg_f0.numpy() # try a hacky stuff
40
+
41
+
42
+ def align_f0(path_to_f0, audio_paths, durs, ratio, mp=False):
43
+ chunk_size = 2000
44
+ num_procs = 40
45
+ iterable = ((path_to_f0[p], d, ratio) for p, d in zip(audio_paths, durs))
46
+
47
+ seg_f0s = []
48
+ if mp:  # NOTE: Pool.istarmap is not a standard Pool method and needs a patch; the default mp=False path below uses itertools.starmap
49
+ with Pool(num_procs) as pool:
50
+ iterator = tqdm(
51
+ pool.istarmap(_align_f0, iterable, chunk_size),
52
+ desc="align f0",
53
+ total=len(durs),
54
+ )
55
+ for seg_f0 in iterator:
56
+ seg_f0s.append(torch.from_numpy(seg_f0).float())
57
+ else:
58
+ iterator = tqdm(starmap(_align_f0, iterable), desc="align f0", total=len(durs))
59
+ for seg_f0 in iterator:
60
+ seg_f0s.append(torch.from_numpy(seg_f0).float())
61
+
62
+ return seg_f0s
63
+
64
+
65
+ def prepare_seg_data(config, audio_paths, codes, durs, speakers, path_to_f0):
66
+ ratio = config.code_hop_size / (config.sampling_rate * F0_FRAME_SPACE)
67
+ seg_f0s = align_f0(path_to_f0, audio_paths, durs, ratio)
68
+ data = {
69
+ "codes": codes,
70
+ "duration": durs,
71
+ "f0": seg_f0s,
72
+ "speaker": speakers,
73
+ "path": audio_paths,
74
+ }
75
+ return data
76
+
77
+
78
+ def dump_seg_data(data, out_prefix):
79
+ key_targs = {
80
+ "codes": f"{out_prefix}.code.npy",
81
+ "duration": f"{out_prefix}.dur.npy",
82
+ "f0": f"{out_prefix}.f0.npy",
83
+ }
84
+ for key, targ in key_targs.items():
85
+ assert not os.path.exists(targ)
86
+ npaa = NpyAppendArray(targ)
87
+ for utt_data in tqdm(data[key], desc=f"dumping {key}"):
88
+ npaa.append(utt_data.numpy())
89
+
90
+ assert not os.path.exists(f"{out_prefix}.path.txt")
91
+ with open(f"{out_prefix}.path.txt", "w") as f:
92
+ for x in data["path"]:
93
+ f.write(f"{str(x)}\n")
94
+
95
+ assert not os.path.exists(f"{out_prefix}.leng.txt")
96
+ with open(f"{out_prefix}.leng.txt", "w") as f:
97
+ for x in data["codes"]:
98
+ f.write(f"{len(x)}\n")
99
+
100
+ assert not os.path.exists(f"{out_prefix}.speaker.txt")
101
+ with open(f"{out_prefix}.speaker.txt", "w") as f:
102
+ for x in data["speaker"]:
103
+ f.write(f"{str(x)}\n")
104
+
105
+ print(f"wrote to files with prefix {out_prefix}")
106
+
107
+
108
+ def main(meta_path, f0_dir, splits, nshards_list):
109
+ speaker_to_stat = defaultdict(F0Stat)
110
+ if len(nshards_list) == 1:
111
+ nshards_list = nshards_list * len(splits)
112
+ else:
113
+ assert len(nshards_list) == len(splits)
114
+
115
+ for split, nshards in zip(splits, nshards_list):
116
+ config, audio_paths, codes, durs, speakers = load_meta(meta_path, split)
117
+ path_to_f0 = load_f0(f"{f0_dir}/{split}", nshards)
118
+
119
+ # segment-level data
120
+ data = prepare_seg_data(config, audio_paths, codes, durs, speakers, path_to_f0)
121
+ dump_seg_data(data, config.manifests[split])
122
+
123
+ # speaker f0
124
+ for audio_path, speaker in tqdm(zip(audio_paths, speakers)):
125
+ f0 = path_to_f0[audio_path]
126
+ speaker_to_stat[speaker].update(f0)
127
+ dump_speaker_f0_stat(speaker_to_stat, config.manifests[split])
128
+
129
+
130
+ if __name__ == "__main__":
131
+ import argparse
132
+
133
+ parser = argparse.ArgumentParser()
134
+ parser.add_argument("meta_path")
135
+ parser.add_argument("f0_dir", help="out_dir from preprocess_f0")
136
+ parser.add_argument("--splits", nargs="+", default=["train", "valid"])
137
+ parser.add_argument(
138
+ "--nshards_list", type=int, nargs="+", default=[20], help="number of f0 shards"
139
+ )
140
+ args = parser.parse_args()
141
+ print(args)
142
+
143
+ main(**vars(args))
fairseq/examples/textless_nlp/pgslm/preprocess_f0.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import torch
8
+ from tqdm import tqdm
9
+ from data_utils import load_audio_path
10
+ from fairseq.data.codedataset import get_f0_by_filename
11
+
12
+
13
+ def process_one(path, sr):
14
+ """
15
+ Args:
16
+ path: audio file path
17
+ sr: sampling rate
18
+ """
19
+ try:
20
+ # YAAPT throws errors in some rare cases
21
+ f0 = get_f0_by_filename(path, sr)
22
+ except Exception as e:
23
+ print(
24
+ f"WARNING: error when processing {path}. set f0 to zero. original error message:\n{e}"
25
+ )
26
+ f0 = None
27
+ return f0
28
+
29
+
30
+ def main(file_path, out_dir, nshards, rank, sampling_rate):
31
+ # load data
32
+ audio_paths = load_audio_path(file_path)
33
+
34
+ # shard
35
+ assert nshards <= len(audio_paths) and nshards > 0
36
+ shard_size = len(audio_paths) / nshards
37
+ s = int(round((rank - 1) * shard_size))
38
+ e = int(round(rank * shard_size))
39
+ audio_paths = audio_paths[s:e]
40
+
41
+ # process
42
+ path_to_f0 = {}
43
+ for i, audio_path in enumerate(tqdm(audio_paths)):
44
+ f0 = process_one(audio_path, sampling_rate)
45
+ path_to_f0[audio_path] = f0
46
+ print(f"finished processing {len(path_to_f0)} utterances ({s}-{e})")
47
+
48
+ f0_path = f"{out_dir}/f0_{rank}_{nshards}.pt"
49
+ os.makedirs(out_dir, exist_ok=True)
50
+ torch.save(path_to_f0, f0_path)
51
+ print(f"saved to {f0_path}")
52
+
53
+
54
+ if __name__ == "__main__":
55
+ import argparse
56
+
57
+ parser = argparse.ArgumentParser()
58
+ parser.add_argument("file_path")
59
+ parser.add_argument("out_dir")
60
+ parser.add_argument("--nshards", type=int, default=20)
61
+ parser.add_argument("--rank", type=int, default=1)
62
+ parser.add_argument("--sampling_rate", type=int, default=16000)
63
+ args = parser.parse_args()
64
+
65
+ main(**vars(args))
fairseq/examples/textless_nlp/pgslm/quantize_f0.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict
7
+ from functools import partial
8
+
9
+ import numpy as np
10
+ import torch
11
+ from tqdm import tqdm
12
+
13
+ from data_utils import dump_speaker_f0_stat, F0Stat, load_audio_path, load_f0
14
+
15
+
16
+ def load_speaker(path):
17
+ speakers = []
18
+ with open(path) as f:
19
+ for line in f.readlines():
20
+ sample = eval(line.strip())
21
+ assert "speaker" in sample
22
+ speakers.append(sample["speaker"])
23
+ return speakers
24
+
25
+
26
+ def quantize_f0(speaker_to_f0, f0_stats, nbins, normalize, log):
27
+ f0_all = []
28
+ for speaker, f0 in speaker_to_f0.items():
29
+ f0 = f0.raw_data
30
+ if log:
31
+ f0 = f0.log()
32
+ mean = f0_stats[speaker]["logf0_mean"] if log else f0_stats[speaker]["f0_mean"]
33
+ std = f0_stats[speaker]["logf0_std"] if log else f0_stats[speaker]["f0_std"]
34
+ if normalize == "mean":
35
+ f0 = f0 - mean
36
+ elif normalize == "meanstd":
37
+ f0 = (f0 - mean) / std
38
+ f0_all.extend(f0.tolist())
39
+
40
+ hist, bin_x = np.histogram(f0_all, 100000)
41
+ cum_hist = np.cumsum(hist) / len(f0_all) * 100
42
+
43
+ f0_bin = {}
44
+ for num_bin in nbins:
45
+ bin_offset = []
46
+ bin_size = 100 / num_bin
47
+ threshold = bin_size
48
+ for i in range(num_bin - 1):
49
+ index = (np.abs(cum_hist - threshold)).argmin()
50
+ bin_offset.append(bin_x[index])
51
+ threshold += bin_size
52
+ f0_bin[num_bin] = np.array(bin_offset)
53
+
54
+ return f0_bin
55
+
56
+
57
+ def main(file_path, f0_dir, out_dir, out_prefix, nbins, nshards, normalize, log):
58
+ audio_paths = load_audio_path(file_path)
59
+ path_to_f0 = load_f0(f0_dir, nshards)
60
+
61
+ speakers = load_speaker(file_path)
62
+ speaker_to_f0 = defaultdict(partial(F0Stat, True))
63
+
64
+ # speaker f0 stats
65
+ for audio_path, speaker in tqdm(zip(audio_paths, speakers)):
66
+ f0 = path_to_f0[audio_path]
67
+ speaker_to_f0[speaker].update(f0)
68
+ f0_stats = dump_speaker_f0_stat(speaker_to_f0, f"{out_dir}/{out_prefix}")
69
+
70
+ # quantize
71
+ f0_bin = quantize_f0(speaker_to_f0, f0_stats, nbins, normalize, log)
72
+ log_suffix = "_log" if log else ""
73
+ f0_bin_out_file = f"{out_dir}/{out_prefix}_{normalize}_norm{log_suffix}_f0_bin.th"
74
+ torch.save(f0_bin, f0_bin_out_file)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ import argparse
79
+
80
+ parser = argparse.ArgumentParser()
81
+ parser.add_argument("file_path")
82
+ parser.add_argument("f0_dir", help="out_dir from preprocess_f0")
83
+ parser.add_argument("out_dir")
84
+ parser.add_argument("out_prefix")
85
+ parser.add_argument("--nbins", nargs="+", type=int, default=[32])
86
+ parser.add_argument("--nshards", type=int, default=20, help="number of f0 shards")
87
+ parser.add_argument(
88
+ "--normalize", type=str, choices=["meanstd", "mean", "none"], default="mean"
89
+ )
90
+ parser.add_argument("--log", action="store_true")
91
+ args = parser.parse_args()
92
+ print(args)
93
+
94
+ main(**vars(args))