Add files using upload-large-folder tool
- fairseq/examples/hubert/tests/sample.xlarge.L30.npy +3 -0
- fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md +47 -0
- fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md +47 -0
- fairseq/examples/textless_nlp/gslm/README.md +21 -0
- fairseq/examples/textless_nlp/gslm/metrics/README.md +10 -0
- fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py +107 -0
- fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md +87 -0
- fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py +201 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/README.md +68 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py +91 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py +141 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py +20 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py +204 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py +70 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py +34 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py +127 -0
- fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py +56 -0
- fairseq/examples/textless_nlp/gslm/tools/README.md +25 -0
- fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py +132 -0
- fairseq/examples/textless_nlp/gslm/ulm/README.md +72 -0
- fairseq/examples/textless_nlp/gslm/ulm/sample.py +174 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/README.md +40 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py +56 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/glow.py +312 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py +27 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py +105 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py +0 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py +93 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py +90 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py +65 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py +103 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py +669 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py +71 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py +141 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py +18 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py +107 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py +171 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py +40 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py +54 -0
- fairseq/examples/textless_nlp/gslm/unit2speech/utils.py +55 -0
- fairseq/examples/textless_nlp/pgslm/README.md +318 -0
- fairseq/examples/textless_nlp/pgslm/data_utils.py +107 -0
- fairseq/examples/textless_nlp/pgslm/eval/__init__.py +4 -0
- fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py +730 -0
- fairseq/examples/textless_nlp/pgslm/generate_waveform.py +120 -0
- fairseq/examples/textless_nlp/pgslm/inference_dataset.py +103 -0
- fairseq/examples/textless_nlp/pgslm/naive_decoder.py +40 -0
- fairseq/examples/textless_nlp/pgslm/prepare_dataset.py +143 -0
- fairseq/examples/textless_nlp/pgslm/preprocess_f0.py +65 -0
- fairseq/examples/textless_nlp/pgslm/quantize_f0.py +94 -0
fairseq/examples/hubert/tests/sample.xlarge.L30.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bbe9f0929ecd4c58786be55215d83f85787ff0d81196bc2e73414f82a8939806
size 3051712
fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md
ADDED
@@ -0,0 +1,47 @@
# Dialogue Speech-to-Unit Encoder for dGSLM: The Fisher HuBERT model
For the speech2unit encoder, we train a [HuBERT model](https://arxiv.org/pdf/2106.07447.pdf) on the [Fisher dataset](http://www.lrec-conf.org/proceedings/lrec2004/pdf/767.pdf) for 3 iterations (see [our paper](https://arxiv.org/pdf/2203.16502.pdf) for more details) and train a k-means model with 500 units on the layer 12 features of the HuBERT model.

## Model checkpoints
The pre-trained HuBERT and k-means model checkpoints can be found here:

| Fisher HuBERT model | k-means model |
|---------------------|---------------|
|[download](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hubert/hubert_fisher.pt)|[download](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hubert/hubert_fisher_km_500.bin)|

## Encode audio to discrete units
Below is an example command to encode a stereo dataset to discrete units using the pre-trained model checkpoints:
```bash
for CHANNEL_ID in 1 2; do
    python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \
        --feature_type hubert \
        --kmeans_model_path path/to/hubert_fisher_km_500.bin \
        --acoustic_model_path path/to/hubert_fisher.pt \
        --layer 12 \
        --manifest_path $MANIFEST_FILE \
        --out_quantized_file_path ${OUTPUT_FILE}-channel${CHANNEL_ID} \
        --extension $EXTENSION \
        --channel_id $CHANNEL_ID
done
```
where MANIFEST_FILE is the output of the [wav2vec manifest script](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/wav2vec_manifest.py), which can be obtained with the following command:
```
python examples/wav2vec/wav2vec_manifest.py --valid-percent=0.0 $AUDIO_DIR --dest=$OUTPUT_DIR --ext=$EXTENSION
```

Alternatively, you can encode an audio file interactively in Python with the HubertTokenizer class:
```python
# Load the Hubert tokenizer
from examples.textless_nlp.dgslm.dgslm_utils import HubertTokenizer
encoder = HubertTokenizer(
    hubert_path = "/path/to/hubert_ckpt.pt",
    hubert_layer = 12,
    km_path = "path/to/km.bin"
)

# Encode the audio to units
path = "/path/to/stereo/audio.wav"
codes = encoder.wav2codes(path)
# > ['7 376 376 133 178 486 486 486 486 486 486 486 486 2 486',
# >  '7 499 415 177 7 7 7 7 7 7 136 136 289 289 408']
```
fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md
ADDED
@@ -0,0 +1,47 @@
# Dialogue Unit-to-Speech Decoder for dGSLM
For the unit2speech decoder, we train a [discrete unit-based HiFi-GAN vocoder](https://arxiv.org/pdf/2104.00355.pdf) on the [Fisher dataset](http://www.lrec-conf.org/proceedings/lrec2004/pdf/767.pdf).

## Model checkpoint
The pre-trained model checkpoint can be found here:

| HiFi-GAN vocoder based on HuBERT Fisher Units |
|-----------------------------------------------|
|[model checkpoint](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hifigan/hifigan_vocoder) - [config](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hifigan/config.json) |

## Decode discrete units to audio
To generate waveforms from discrete units, use the script `generate_stereo_waveform.py`:
```bash
python examples/textless_nlp/dgslm/vocoder_hifigan/generate_stereo_waveform.py \
    --in-file $INPUT_CODE_FILE \
    --vocoder $VOCODER_PATH \
    --vocoder-cfg $VOCODER_CONFIG \
    --results-path $OUTPUT_DIR
```
where INPUT_CODE_FILE is expected to have the following format:
```
{'audio': 'file_1', 'unitA': '8 8 ... 352 352', 'unitB': '217 8 ... 8 8'}
{'audio': 'file_2', 'unitA': '5 5 ... 65 65', 'unitB': '6 35 ... 8 9'}
...
```

You can also use the HifiganVocoder class to generate waveforms from the codes interactively:
```python
# Load the Hifigan vocoder
from examples.textless_nlp.dgslm.dgslm_utils import HifiganVocoder
decoder = HifiganVocoder(
    vocoder_path = "/path/to/hifigan_vocoder",
    vocoder_cfg_path = "/path/to/config.json",
)

# Decode the units to waveform
codes = [
    '7 376 376 133 178 486 486 486 486 486 486 486 486 2 486',
    '7 499 415 177 7 7 7 7 7 7 136 136 289 289 408',
]
wav = decoder.codes2wav(codes)
# > array of shape (2, 4800)

# Play the waveform
import IPython.display as ipd
ipd.Audio(wav, rate=16_000)
```
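The `INPUT_CODE_FILE` format shown in the README above can be assembled from the two per-channel outputs of `quantize_with_kmeans.py`. Below is a minimal, illustrative sketch only: it assumes the units were dumped with `--channel_id 1`/`2` (which appends a `-channel<N>` suffix to each file name) and without `--hide-fname`; the paths and the pairing convention are placeholders, not part of the released tooling.

```python
# Pair per-channel unit files ("<name>-channelN|<units>" lines) into dict-per-line format.
def read_units(path, suffix):
    units = {}
    with open(path) as f:
        for line in f:
            name, codes = line.rstrip("\n").split("|", 1)
            units[name.replace(suffix, "")] = codes  # drop the "-channelN" suffix
    return units

unit_a = read_units("quantized-channel1", "-channel1")  # placeholder paths
unit_b = read_units("quantized-channel2", "-channel2")

with open("codes.txt", "w") as fout:
    for audio in sorted(set(unit_a) & set(unit_b)):
        fout.write(str({"audio": audio, "unitA": unit_a[audio], "unitB": unit_b[audio]}) + "\n")
```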
fairseq/examples/textless_nlp/gslm/README.md
ADDED
@@ -0,0 +1,21 @@
# Generative Spoken Language Modeling

* [Paper](https://arxiv.org/abs/2102.01192)
* [Demo](https://speechbot.github.io/gslm/index.html)

We build and evaluate generative speech2speech systems using [Log Mel Filterbank](https://pytorch.org/audio/stable/compliance.kaldi.html#fbank), [Modified CPC](https://github.com/facebookresearch/CPC_audio), [HuBERT Base](https://github.com/pytorch/fairseq/tree/main/examples/hubert) and [Wav2Vec 2.0 Large](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec). Our system is composed of three components: *speech2unit*, *ulm* and *unit2speech*. We describe the models and the usage of these components in their respective sub-directories, linked below.

## Speech to Unit Model (speech2unit)
The speech to unit model quantizes raw speech into learned discrete speech units. [More details](speech2unit)

## Unit Language Model (ulm)
The unit language model is a generative language model trained on discrete speech units. [More details](ulm)

## Unit to Speech Model (unit2speech)
The unit to speech model synthesizes speech from discrete speech units. [More details](unit2speech)

## Metrics
We show how to compute ASR-based metrics as well as the zero-shot metrics proposed in our paper [here](metrics).

## Tools
We share two tools to resynthesize a given spoken utterance and to generate novel spoken language given a spoken prompt. [More details](tools)
fairseq/examples/textless_nlp/gslm/metrics/README.md
ADDED
@@ -0,0 +1,10 @@
# GSLM Metrics

## ASR Metrics
The suite of metrics here uses an ASR model to transcribe the synthesized speech into text and then applies text-based metrics. We also use the word error rate of the ASR transcription itself as one of the metrics. [More details](asr_metrics)

## ABX Metrics
We use [ABX](https://www.semanticscholar.org/paper/ABX-Discriminability-Measures-and-Applications-Schatz/13d3537228f728c1063cc83743cb118bba3367a0) to evaluate how well-separated phonetic categories are with quantized representations. [More details](abx_metrics)

## sWUGGY and sBLIMP
We refer to the [ZeroSpeech challenge](https://www.zerospeech.com/2021/track_s.html#scoring-based-metrics) for details on the sWUGGY and sBLIMP metrics.
fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os

import joblib
import numpy as np

from examples.textless_nlp.gslm.speech2unit.clustering.utils import get_audio_files
from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_features


def get_logger():
    log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger


def get_parser():
    parser = argparse.ArgumentParser(
        description="Dump one-hot K-means cluster assignments over acoustic features for ABX evaluation."
    )
    parser.add_argument(
        "--feature_type",
        type=str,
        choices=["logmel", "hubert", "w2v2", "cpc"],
        default=None,
        required=True,
        help="Acoustic feature type",
    )
    parser.add_argument(
        "--kmeans_model_path",
        type=str,
        required=True,
        help="K-means model file path to use for inference",
    )
    parser.add_argument(
        "--manifest_path",
        type=str,
        default=None,
        help="Manifest file containing the root dir and file names",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Pretrained model checkpoint",
    )
    parser.add_argument(
        "--layer",
        type=int,
        help="The layer of the pretrained model to extract features from",
        default=-1,
    )
    parser.add_argument(
        "--out_dir_path",
        required=True,
        type=str,
        help="Output directory for the quantized features.",
    )
    parser.add_argument(
        "--extension", type=str, default=".flac", help="Audio file extension"
    )
    return parser


def one_hot(feat, n_clusters):
    return np.eye(n_clusters)[feat]


def main(args, logger):
    # Feature extraction
    logger.info(f"Extracting {args.feature_type} acoustic features...")
    features_batch = get_features(
        feature_type=args.feature_type,
        checkpoint_path=args.checkpoint_path,
        layer=args.layer,
        manifest_path=args.manifest_path,
        sample_pct=1.0,
        flatten=False,
    )
    logger.info(f"Features extracted for {len(features_batch)} utterances.\n")
    logger.info(f"Dimensionality of representation = {features_batch[0].shape[1]}")

    logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
    kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
    kmeans_model.verbose = False

    _, fnames, _ = get_audio_files(args.manifest_path)

    os.makedirs(args.out_dir_path, exist_ok=True)
    logger.info(f"Writing quantized features to {args.out_dir_path}")
    for i, feats in enumerate(features_batch):
        pred = kmeans_model.predict(feats)
        emb = one_hot(pred, kmeans_model.n_clusters)
        base_fname = os.path.basename(fnames[i]).rstrip(args.extension)
        output_path = os.path.join(args.out_dir_path, f"{base_fname}.npy")
        with open(output_path, "wb") as f:
            np.save(f, emb)


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    logger = get_logger()
    logger.info(args)
    main(args, logger)
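For reference, the script above writes one `.npy` file per utterance containing a `(num_frames, n_clusters)` one-hot matrix. A minimal sketch of loading it back (the path is a placeholder):

```python
import numpy as np

feats = np.load("/path/to/out_dir/utt1.npy")   # shape: (num_frames, n_clusters)
print(feats.shape)
cluster_ids = feats.argmax(axis=1)             # recover the discrete unit sequence
```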
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md
ADDED
@@ -0,0 +1,87 @@
# ASR-based evaluation

Overall, the life cycle of the ASR-based evaluation for a ULM contains the following steps:
1. Training a ULM and sampling from it [[description]](./../../ulm)
2. Running UTS on the sampled unit sequences [[description]](./../../unit2speech)
3. Pre-processing for the ASR (down-sampling to 16 kHz, aligning the length of the generated audio with the ground-truth utterances)
4. Running ASR
5. Calculation of the post-ASR evaluation metrics

Here we assume that you have already gone through the first two steps and focus on the rest.

## Preprocessing
### Down-sampling to 16 kHz
The bulk conversion can be done by running
```bash
python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py $UTS_OUTPUT $UTS_OUTPUT_DOWNSAMPLE
```
where `$UTS_OUTPUT` specifies the directory with the generated audio and `$UTS_OUTPUT_DOWNSAMPLE` is the directory where the downsampled audio will be saved.

### Matching by length
This step is somewhat optional. However, if you want to compare the fluency and diversity of a generated speech utterance to that of the ground-truth speech with the same prefix, it is a good idea to force them to be of the same length.
```bash
python $FAIRSEQ_ROOT/examples/textless_nlp/asr_metrics/cut_as.py \
    --samples_dir=$UTS_OUTPUT_DOWNSAMPLE --out_dir=$UTS_OUTPUT_DOWNSAMPLE_CUT \
    --prompts_description=data/ground_truth_continuation_dev.json
```

Here `ground_truth_continuation_dev.json` is a json file with ground-truth text from LibriSpeech dev-clean, associated with some meta-data (assuming the evaluation is done on dev-clean). This file can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_dev.json). A similar file for test-clean is [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_test.json). These files are used for the evaluation and contain texts for audio sequences that are at least 6s long.

## Running ASR
We use a pre-trained wav2vec model to run the ASR step. We first need to prepare manifest files which, roughly, tell the ASR system which files we want to transcribe. You can find more details and download the `960h_scratch.pt` checkpoint [[here]](https://github.com/pytorch/fairseq/blob/main/examples/wav2vec/README.md). To run ASR, you would also need to install KenLM, the Flashlight decoder, and download the KenLM 4-gram English language model.

```bash
python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py \
    $UTS_OUTPUT_DOWNSAMPLE_CUT --valid-percent 0.0 --dest $MANIFEST_DIR --ext wav
```
where `$UTS_OUTPUT_DOWNSAMPLE_CUT` specifies the directory with the preprocessed UTS outputs and `$MANIFEST_DIR` is the output directory.

We will be running an out-of-the-box evaluation script which requires ground-truth transcripts to measure quality metrics. We are only interested in the transcripts (and we don't have ground-truth transcripts for what our ULM generated!), hence we will just generate some dummy transcripts instead:
```bash
cp $FAIRSEQ_ROOT/examples/textless_nlp/gslm/asr_metrics/misc/dict.ltr.txt $MANIFEST_DIR
python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/asr_metrics/misc/dummy_asr_data.py --tsv=$MANIFEST_DIR/train.tsv \
    --output-dir=$MANIFEST_DIR
```

Now we are ready to run ASR:
```
mkdir -p asr
python $FAIRSEQ_ROOT/examples/speech_recognition/infer.py \
    $MANIFEST_DIR \
    --task audio_pretraining --nbest 1 --path 960h_scratch.pt \
    --gen-subset=train --results-path $PATH_TO_ASR_OUTPUT \
    --w2l-decoder kenlm --lm-model 4-gram.bin \
    --lexicon librispeech/lexicon_ltr.lst --word-score -1 \
    --sil-weight 0 --lm-weight 2 --criterion ctc --labels ltr --max-tokens 300000 --remove-bpe letter
```
where `lexicon_ltr.lst` is the LibriSpeech lexicon (which can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/lexicon_ltr.lst)) and `$PATH_TO_ASR_OUTPUT` is the output directory.

## Evaluation metrics
We run evaluation on the 1,000 shortest sequences that are at least 6s long. To filter those from the ASR transcript, we additionally provide each metric script with the paths to the manifest and the `ground_truth_continuation_*` files.

### Perplexity (PPX)
To get a PPX estimate on an ASR transcript, run the following command:
```bash
python ppx.py $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \
    --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
```
where `--cut-tail` tells the script to ignore the last token on each line (ASR puts the sequence ID there).

### Self- and Auto-BLEU
```bash
python self_bleu.py $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \
    --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
```

### Continuation-BLEU
```bash
python continuation_eval.py --asr-transcript $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt \
    --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
```

### AUC
Based on the metrics calculated above, we can estimate the AUC of the perplexity/diversity trade-off. We provide an illustration in a [Colab notebook](https://colab.research.google.com/drive/1pVPfOVax_PU3MkYdHRSsa-SI8GBUldNt?usp=sharing).
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py
ADDED
@@ -0,0 +1,201 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import nltk
from misc.bleu_utils import sentence_bleu
import warnings


def get_target_sequences(manifest, ground_truth, to_take=1000):
    import json
    import pathlib

    with open(ground_truth, 'r') as fin:
        original_continuations = json.loads(fin.read())

    sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
    assert all(float(v) >= 6.0 for (_, v) in sequence2length)  # 6 seconds

    sequence2length.sort(key=lambda x: x[1])
    to_take_sequences = set(v[0] for v in sequence2length[:to_take])
    to_take_ids = []

    with open(manifest, 'r') as f:
        f.readline()

        for i, line in enumerate(f.readlines()):
            seq_id = line.split()[0]
            seq_id = pathlib.Path(seq_id).name.split('__')[0]

            if seq_id in to_take_sequences:
                to_take_ids.append(i)

    print(f'Took {len(to_take_ids)} ids')
    return set(to_take_ids)


def get_args():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--asr-transcript', type=str,
                        help='Path to the transcript file.')

    parser.add_argument('--manifest', required=True)
    parser.add_argument('--prompts-description', required=True)

    parser.add_argument('--cut-id', action='store_true',
                        help='Whether to cut the first token (typically a seq id)')
    parser.add_argument('--cut-tail', action='store_true',
                        help='Whether to cut the last token (typically a speaker id)')
    parser.add_argument('--debug', action='store_true')

    args = parser.parse_args()

    return args


def get_self_bleu(utterances, averaging_mode, weights):
    self_bleu = []

    for i in range(len(utterances)):
        hypo = utterances[i]
        rest = utterances[:i] + utterances[i+1:]

        self_bleu.append(sentence_bleu(rest, hypo, weights,
                                       no_length_penalty=True, averaging_mode=averaging_mode))

    return self_bleu


def get_self_bleu2_arithmetic(utterances):
    weights = (0.5, 0.5)  # equal weight for unigrams and bigrams
    return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)


def get_self_bleu2_geometric(utterances):
    weights = (0.5, 0.5)
    return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)


def get_auto_bleu2_arithmetic(utterances):
    weights = (0.5, 0.5)
    return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]


def get_auto_bleu2_geometric(utterances):
    weights = (0.5, 0.5)
    return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]


def get_auto_bleu3_geometric(utterances):
    weights = (1./3, 1./3, 1./3)
    return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]


def get_auto_bleu3_arithmetic(utterances):
    weights = (1./3, 1./3, 1./3)
    return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]


def get_self_bleu3_arithmetic(utterances):
    weights = (1./3, 1./3, 1./3)
    return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)


def get_self_bleu3_geometric(utterances):
    weights = (1./3, 1./3, 1./3)
    return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)


def auto_bleu(sentence, weights, mean_mode='arithmetic'):
    if len(sentence) <= 1:
        return 0

    N = len(weights)

    bleu_n = np.zeros([N])
    for n in range(N):
        targ_ngrams = list(nltk.ngrams(sentence, n+1))
        for p in range(len(targ_ngrams)):
            left = sentence[:p]
            right = sentence[(p+n+1):]
            rest_ngrams = list(nltk.ngrams(left, n+1)) + \
                list(nltk.ngrams(right, n+1))
            # count the number of matching n-grams
            bleu_n[n] += targ_ngrams[p] in rest_ngrams
        bleu_n[n] /= len(targ_ngrams)  # average to get a proportion

    weights = np.array(weights)
    if mean_mode == 'arithmetic':
        return (bleu_n * weights).sum()
    elif mean_mode == 'geometric':
        return (bleu_n ** weights).prod()
    else:
        raise ValueError(f'Unknown aggregation mode {mean_mode}')


def main():
    from multiprocessing import Pool

    args = get_args()
    target_ids = get_target_sequences(args.manifest, args.prompts_description)

    with open(args.asr_transcript, 'r') as fin:
        lines = fin.readlines()

    terms = [x.strip().split() for x in lines]
    filtered = []
    for term in terms:
        line_id = int(term[-1].split('-')[1][:-1])
        if line_id in target_ids:
            filtered.append(term)
    terms = filtered

    if args.cut_id:
        terms = [x[1:] for x in terms]
    if args.cut_tail:
        terms = [x[:-1] for x in terms]

    if args.debug:
        terms = terms[:10]

    tasks = [
        ('Self-BLEU2-arithmetic', get_self_bleu2_arithmetic),
        ('Self-BLEU2-geometric', get_self_bleu2_geometric),
        ('Auto-BLEU2-arithmetic', get_auto_bleu2_arithmetic),
        ('Auto-BLEU2-geometric', get_auto_bleu2_geometric),

        ('Self-BLEU3-arithmetic', get_self_bleu3_arithmetic),
        ('Self-BLEU3-geometric', get_self_bleu3_geometric),
        ('Auto-BLEU3-arithmetic', get_auto_bleu3_arithmetic),
        ('Auto-BLEU3-geometric', get_auto_bleu3_geometric),
    ]

    n_processes = min(16, len(tasks))
    with Pool(n_processes) as pool:
        metrics = pool.map(run_f, [(t[1], terms) for t in tasks])

    for (metric_name, _), metric in zip(tasks, metrics):
        metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))

        metric, sem = [
            round(100 * x, 2) for x in [metric, sem]
        ]

        print(f'{metric_name} {metric} +- {sem}')


def run_f(task_params):
    f, terms = task_params
    return f(terms)


if __name__ == '__main__':
    # NLTK produces warnings
    warnings.filterwarnings("ignore")

    main()
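As a small sanity check (not part of the script itself), `auto_bleu` can be called directly on a tokenized utterance; higher values mean the utterance repeats its own n-grams more:

```python
from self_auto_bleu import auto_bleu  # assumes the file above is importable as a module

# Toy fully-repetitive "sentence": every unigram recurs and 2 of 3 bigrams recur,
# so arithmetic Auto-BLEU2 is 0.5 * 1.0 + 0.5 * (2/3) ~= 0.83.
print(auto_bleu(["a", "b", "a", "b"], weights=(0.5, 0.5), mean_mode="arithmetic"))
```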
fairseq/examples/textless_nlp/gslm/speech2unit/README.md
ADDED
@@ -0,0 +1,68 @@
# Speech to Unit Model (speech2unit)

## Acoustic Model
For quantizing speech we learn a K-means clustering over acoustic representations, for which we use either Log-Mel Filterbank features or pretrained acoustic representation models. To use the pretrained models, please download them from the locations linked below.
* [Modified CPC](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/cpc_big_ll6kh_top_ctc.pt)
* [HuBERT-Base](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt)
* [Wav2Vec 2.0-Base](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt)

## Quantization Model
You can download pretrained quantization (K-means) models from the list below.

K-Means Model | Download Link
|-|-
Log Mel Filterbank + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km50/km.bin)
Log Mel Filterbank + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km100/km.bin)
Log Mel Filterbank + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km200/km.bin)
Modified CPC + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km50/km.bin)
Modified CPC + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km100/km.bin)
Modified CPC + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km200/km.bin)
HuBERT Base + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km50/km.bin)
HuBERT Base + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km100/km.bin)
HuBERT Base + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km200/km.bin)
wav2vec 2.0 Large + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km50/km.bin)
wav2vec 2.0 Large + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km100/km.bin)
wav2vec 2.0 Large + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km200/km.bin)

### Quantization
For quantizing speech with a given acoustic representation, please follow the steps below.
1. Learn K-means clustering model
```
N_CLUSTERS=<number_of_clusters_used_for_kmeans>
TYPE=<one_of_logmel/cpc/hubert/w2v2>
CKPT_PATH=<path_of_pretrained_acoustic_model>
LAYER=<layer_of_acoustic_model_to_extract_features_from>
MANIFEST=<tab_separated_manifest_of_audio_files_for_training_kmeans>
KM_MODEL_PATH=<output_path_of_the_kmeans_model>

PYTHONPATH=. python examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py \
    --num_clusters $N_CLUSTERS \
    --feature_type $TYPE \
    --checkpoint_path $CKPT_PATH \
    --layer $LAYER \
    --manifest_path $MANIFEST \
    --out_kmeans_model_path $KM_MODEL_PATH
```
2. Quantize using the learned clusters
```
MANIFEST=<tab_separated_manifest_of_audio_files_to_quantize>
OUT_QUANTIZED_FILE=<output_quantized_audio_file_path>

python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \
    --feature_type $TYPE \
    --kmeans_model_path $KM_MODEL_PATH \
    --acoustic_model_path $CKPT_PATH \
    --layer $LAYER \
    --manifest_path $MANIFEST \
    --out_quantized_file_path $OUT_QUANTIZED_FILE \
    --extension ".flac"
```

Note: the manifest file lists the paths and lengths of the input audio files. Its format is as follows:
```
<path_of_root_directory_containing_audio_files>
<relative_path_of_audio_file_1>\t<number_of_frames_1>
<relative_path_of_audio_file_2>\t<number_of_frames_2>
...
```
ADDED
@@ -0,0 +1,91 @@
|
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging

from examples.textless_nlp.gslm.speech2unit.pretrained.utils import (
    get_and_dump_features,
)


def get_parser():
    parser = argparse.ArgumentParser(
        description="Compute and dump log mel fbank features."
    )
    parser.add_argument(
        "--feature_type",
        type=str,
        choices=["logmel", "hubert", "w2v2", "cpc"],
        default=None,
        help="Acoustic feature type",
    )
    parser.add_argument(
        "--manifest_path",
        type=str,
        default=None,
        help="Manifest file containing the root dir and file names",
    )
    parser.add_argument(
        "--out_features_path",
        type=str,
        default=None,
        help="Features file path to write to",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Pretrained acoustic model checkpoint",
    )
    parser.add_argument(
        "--layer",
        type=int,
        help="The layer of the pretrained model to extract features from",
        default=-1,
    )
    parser.add_argument(
        "--sample_pct",
        type=float,
        help="Percent data to use for K-means training",
        default=0.1,
    )
    return parser


def get_logger():
    log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger


if __name__ == "__main__":
    """
    Example command:
    python ~/speechbot/clustering/dump_logmelfank_feats.py \
        --manifest_path /checkpoint/kushall/data/LJSpeech-1.1/asr_input_wavs_16k/train.tsv
        --out_features_path /checkpoint/kushall/experiments/speechbot/logmelfbank/features/ljspeech/train.npy
    """
    parser = get_parser()
    args = parser.parse_args()
    logger = get_logger()
    logger.info(args)

    logger.info(f"Extracting {args.feature_type} acoustic features...")
    get_and_dump_features(
        feature_type=args.feature_type,
        checkpoint_path=args.checkpoint_path,
        layer=args.layer,
        manifest_path=args.manifest_path,
        sample_pct=args.sample_pct,
        flatten=True,
        out_features_path=args.out_features_path,
    )
    logger.info(f"Saved extracted features at {args.out_features_path}")
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py
ADDED
@@ -0,0 +1,141 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os

import numpy as np

import joblib
from examples.textless_nlp.gslm.speech2unit.clustering.utils import (
    get_audio_files,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.utils import (
    get_features,
)


def get_logger():
    log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger


def get_parser():
    parser = argparse.ArgumentParser(
        description="Quantize using K-means clustering over acoustic features."
    )
    parser.add_argument(
        "--feature_type",
        type=str,
        choices=["logmel", "hubert", "w2v2", "cpc"],
        default=None,
        required=True,
        help="Acoustic feature type",
    )
    parser.add_argument(
        "--acoustic_model_path",
        type=str,
        help="Pretrained acoustic model checkpoint"
    )
    parser.add_argument(
        "--layer",
        type=int,
        help="The layer of the pretrained model to extract features from",
        default=-1,
    )
    parser.add_argument(
        "--kmeans_model_path",
        type=str,
        required=True,
        help="K-means model file path to use for inference",
    )
    parser.add_argument(
        "--features_path",
        type=str,
        default=None,
        help="Features file path. You don't need to enter acoustic model details if you have dumped features",
    )
    parser.add_argument(
        "--manifest_path",
        type=str,
        default=None,
        help="Manifest file containing the root dir and file names",
    )
    parser.add_argument(
        "--out_quantized_file_path",
        required=True,
        type=str,
        help="File path of quantized output.",
    )
    parser.add_argument(
        "--extension", type=str, default=".flac", help="Audio file extension"
    )
    parser.add_argument(
        "--channel_id",
        choices=['1', '2'],
        help="The audio channel to extract the units from, in the case of a stereo file.",
        default=None,
    )
    parser.add_argument(
        "--hide-fname", action='store_true',
        help="Hide file names in the output file."
    )
    return parser


def main(args, logger):
    # Feature extraction
    if args.features_path is not None:
        logger.info(f"Loading acoustic features from {args.features_path}...")
        features_batch = np.load(args.features_path)
    else:
        logger.info(f"Extracting {args.feature_type} acoustic features...")
        features_batch = get_features(
            feature_type=args.feature_type,
            checkpoint_path=args.acoustic_model_path,
            layer=args.layer,
            manifest_path=args.manifest_path,
            sample_pct=1.0,
            flatten=False,
            channel_id=int(args.channel_id) if args.channel_id else None,
        )
        logger.info(
            f"Features extracted for {len(features_batch)} utterances.\n"
        )
        logger.info(
            f"Dimensionality of representation = {features_batch[0].shape[1]}"
        )

    # K-means model
    logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
    kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
    kmeans_model.verbose = False

    _, fnames, _ = get_audio_files(args.manifest_path)

    os.makedirs(os.path.dirname(args.out_quantized_file_path), exist_ok=True)
    print(f"Writing quantized predictions to {args.out_quantized_file_path}")
    with open(args.out_quantized_file_path, "w") as fout:
        for i, feats in enumerate(features_batch):
            pred = kmeans_model.predict(feats)
            pred_str = " ".join(str(p) for p in pred)
            base_fname = os.path.basename(fnames[i]).rstrip('.' + args.extension.lstrip('.'))
            if args.channel_id is not None:
                base_fname = base_fname + f'-channel{args.channel_id}'
            if not args.hide_fname:
                fout.write(f"{base_fname}|{pred_str}\n")
            else:
                fout.write(f"{pred_str}\n")


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    logger = get_logger()
    logger.info(args)
    main(args, logger)
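A minimal sketch (the file path is a placeholder) of reading the resulting quantized file back, assuming it was written without `--hide-fname`:

```python
# Each line is "<file_name>|<space-separated cluster ids>".
units = {}
with open("/path/to/quantized.txt") as f:
    for line in f:
        name, codes = line.rstrip("\n").split("|", 1)
        units[name] = [int(c) for c in codes.split()]
```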
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py
ADDED
@@ -0,0 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple


def get_audio_files(manifest_path: str) -> Tuple[str, List[str], List[int]]:
    fnames, sizes = [], []
    with open(manifest_path, "r") as f:
        root_dir = f.readline().strip()
        for line in f:
            items = line.strip().split("\t")
            assert (
                len(items) == 2
            ), f"File must have two columns separated by tab. Got {line}"
            fnames.append(items[0])
            sizes.append(int(items[1]))
    return root_dir, fnames, sizes
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py
ADDED
@@ -0,0 +1,204 @@
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F


class CpcFeatureReader:
    """
    Wrapper class to run inference on CPC model.
    Helps extract features for a given audio file.
    """

    def __init__(
        self,
        checkpoint_path,
        layer,
        use_encoder_layer=False,
        norm_features=False,
        sample_rate=16000,
        max_chunk=64000,
        use_cuda=True,
    ):
        self.model = load_cpc_model(checkpoint_path, layer).eval()
        self.sample_rate = sample_rate
        self.max_chunk = max_chunk
        self.norm_features = norm_features
        self.use_encoder_layer = use_encoder_layer
        self.use_cuda = use_cuda
        if self.use_cuda:
            self.model.cuda()

    def read_audio(self, path, ref_len=None, channel_id=None):
        wav, sr = sf.read(path)
        if channel_id is not None:
            assert wav.ndim == 2, \
                f"Expected stereo input when channel_id is given ({path})"
            assert channel_id in [1, 2], \
                "channel_id is expected to be in [1, 2]"
            wav = wav[:, channel_id - 1]
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        assert sr == self.sample_rate, sr
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            print(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, file_path, ref_len=None, channel_id=None):
        x = self.read_audio(file_path, ref_len, channel_id)
        # Inspired from CPC_audio feature_loader.py
        with torch.no_grad():
            x = torch.from_numpy(x).float()
            if self.use_cuda:
                x = x.cuda()
            x = x.view(1, 1, -1)
            size = x.size(2)
            feat = []
            start = 0
            while start < size:
                if start + self.max_chunk > size:
                    break
                x_chunk = x[..., start : start + self.max_chunk]
                feat_chunk = self.model.extract_features(
                    source=x_chunk,
                    get_encoded=self.use_encoder_layer,
                    norm_output=self.norm_features,
                )
                feat.append(feat_chunk)
                start += self.max_chunk

            if start < size:
                x_chunk = x[:, -self.max_chunk :]
                feat_chunk = self.model.extract_features(
                    source=x_chunk,
                    get_encoded=self.use_encoder_layer,
                    norm_output=self.norm_features,
                )
                df = x_chunk.size(2) // feat_chunk.size(1)
                delta = (size - start) // df
                feat.append(feat_chunk[:, -delta:])
        return torch.cat(feat, 1).squeeze(0)


def load_cpc_model(checkpoint_path, layer=None):
    state_dict = torch.load(checkpoint_path)
    weights = state_dict["weights"]
    config = state_dict["config"]
    if layer is not None:
        config["nLevelsGRU"] = layer

    encoder = CPCEncoder(config["hiddenEncoder"])
    ar_net = CPCAR(
        config["hiddenEncoder"], config["hiddenGar"], False, config["nLevelsGRU"]
    )

    model = CPCModel(encoder, ar_net)
    model.load_state_dict(weights, strict=False)
    model.config = config

    return model


class ChannelNorm(nn.Module):
    def __init__(self, num_features, epsilon=1e-05, affine=True):
        super(ChannelNorm, self).__init__()
        if affine:
            self.weight = nn.parameter.Parameter(torch.Tensor(1, num_features, 1))
            self.bias = nn.parameter.Parameter(torch.Tensor(1, num_features, 1))
        else:
            self.weight = None
            self.bias = None
        self.epsilon = epsilon
        self.p = 0
        self.affine = affine
        self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            torch.nn.init.ones_(self.weight)
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        cum_mean = x.mean(dim=1, keepdim=True)
        cum_var = x.var(dim=1, keepdim=True)
        x = (x - cum_mean) * torch.rsqrt(cum_var + self.epsilon)
        if self.weight is not None:
            x = x * self.weight + self.bias
        return x


class CPCEncoder(nn.Module):
    def __init__(self, hidden_dim=512):
        super(CPCEncoder, self).__init__()
        self.conv0 = nn.Conv1d(1, hidden_dim, 10, stride=5, padding=3)
        self.batchNorm0 = ChannelNorm(hidden_dim)
        self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, 8, stride=4, padding=2)
        self.batchNorm1 = ChannelNorm(hidden_dim)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
        self.batchNorm2 = ChannelNorm(hidden_dim)
        self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
        self.batchNorm3 = ChannelNorm(hidden_dim)
        self.conv4 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
        self.batchNorm4 = ChannelNorm(hidden_dim)
        self.DOWNSAMPLING = 160

    def get_output_dim(self):
        return self.conv4.out_channels

    def forward(self, x):
        x = F.relu(self.batchNorm0(self.conv0(x)))
        x = F.relu(self.batchNorm1(self.conv1(x)))
        x = F.relu(self.batchNorm2(self.conv2(x)))
        x = F.relu(self.batchNorm3(self.conv3(x)))
        x = F.relu(self.batchNorm4(self.conv4(x)))
        return x


class CPCAR(nn.Module):
    def __init__(self, dim_encoded, dim_output, keep_hidden, num_layers):
        super(CPCAR, self).__init__()
        self.baseNet = nn.LSTM(
            dim_encoded, dim_output, num_layers=num_layers, batch_first=True
        )
        self.hidden = None
        self.keep_hidden = keep_hidden

    def get_output_dim(self):
        return self.baseNet.hidden_size

    def forward(self, x):
        try:
            self.baseNet.flatten_parameters()
        except RuntimeError:
            pass
        x, h = self.baseNet(x, self.hidden)
        if self.keep_hidden:
            if isinstance(h, tuple):
                self.hidden = tuple(x.detach() for x in h)
            else:
                self.hidden = h.detach()
        return x


class CPCModel(nn.Module):
    def __init__(self, encoder, ar_net):
        super(CPCModel, self).__init__()
        self.gEncoder = encoder
        self.gAR = ar_net
        self.config = None

    def forward(self, x, label):
        encoded = self.gEncoder(x).permute(0, 2, 1)
        cpc_feature = self.gAR(encoded)
        return cpc_feature, encoded, label

    def extract_features(self, source, get_encoded=False, norm_output=False):
        cpc_feature, encoded, _ = self.forward(source, None)
        if get_encoded:
            cpc_feature = encoded
        if norm_output:
            mean = cpc_feature.mean(dim=1, keepdim=True)
            var = cpc_feature.var(dim=1, keepdim=True)
            cpc_feature = (cpc_feature - mean) / torch.sqrt(var + 1e-08)
        return cpc_feature
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py
ADDED
@@ -0,0 +1,70 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import fairseq
import soundfile as sf
import torch.nn.functional as F


class HubertFeatureReader:
    """
    Wrapper class to run inference on HuBERT model.
    Helps extract features for a given audio file.
    """

    def __init__(self, checkpoint_path, layer, max_chunk=1600000, use_cuda=True):
        (
            model,
            cfg,
            task,
        ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path]
        )
        self.model = model[0].eval()
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        self.use_cuda = use_cuda
        if self.use_cuda:
            self.model.cuda()

    def read_audio(self, path, ref_len=None, channel_id=None):
        wav, sr = sf.read(path)
        if channel_id is not None:
            assert wav.ndim == 2, \
                f"Expected stereo input when channel_id is given ({path})"
            assert channel_id in [1, 2], \
                "channel_id is expected to be in [1, 2]"
            wav = wav[:, channel_id - 1]
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        assert sr == self.task.cfg.sample_rate, sr
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            print(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, file_path, ref_len=None, channel_id=None):
        x = self.read_audio(file_path, ref_len, channel_id)
        with torch.no_grad():
            x = torch.from_numpy(x).float()
            if self.use_cuda:
                x = x.cuda()
            if self.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)

            feat = []
            for start in range(0, x.size(1), self.max_chunk):
                x_chunk = x[:, start: start + self.max_chunk]
                feat_chunk, _ = self.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    output_layer=self.layer,
                )
                feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)
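For reference, a minimal sketch of how this reader is typically combined with a pretrained K-means model, mirroring what `quantize_with_kmeans.py` does through `get_features` (the checkpoint paths and layer are placeholders):

```python
import joblib

reader = HubertFeatureReader("/path/to/hubert.pt", layer=12)  # placeholder path/layer
kmeans = joblib.load("/path/to/km.bin")                       # placeholder path

feats = reader.get_feats("/path/to/audio.wav")    # (num_frames, feature_dim) torch tensor
units = kmeans.predict(feats.cpu().numpy())       # one cluster id per frame
print(" ".join(str(u) for u in units))
```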
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import soundfile as sf
import torch
import torchaudio.compliance.kaldi as kaldi


class LogMelFeatureReader:
    """
    Wrapper class to compute log Mel filterbank features.
    Helps extract features for a given audio file.
    """

    def __init__(self, *args, **kwargs):
        self.num_mel_bins = kwargs.get("num_mel_bins", 80)
        self.frame_length = kwargs.get("frame_length", 25.0)

    def get_feats(self, file_path, channel_id=None):
        wav, sr = sf.read(file_path)
        if channel_id is not None:
            assert wav.ndim == 2, \
                f"Expected stereo input when channel_id is given ({file_path})"
            wav = wav[:, channel_id - 1]
        feats = torch.from_numpy(wav).float()
        feats = kaldi.fbank(
            feats.unsqueeze(0),
            num_mel_bins=self.num_mel_bins,
            frame_length=self.frame_length,
            sample_frequency=sr,
        )
        return feats
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py
ADDED
@@ -0,0 +1,127 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import gc
import os
import random
import shutil
import numpy as np

import torch
import tqdm
from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import (
    CpcFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import (
    HubertFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import (
    LogMelFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import (
    Wav2VecFeatureReader,
)


def get_feature_reader(feature_type):
    if feature_type == "logmel":
        return LogMelFeatureReader
    elif feature_type == "hubert":
        return HubertFeatureReader
    elif feature_type == "w2v2":
        return Wav2VecFeatureReader
    elif feature_type == "cpc":
        return CpcFeatureReader
    else:
        raise NotImplementedError(f"{feature_type} is not supported.")


def get_feature_iterator(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct, channel_id=None
):
    feature_reader_cls = get_feature_reader(feature_type)
    with open(manifest_path, "r") as fp:
        lines = fp.read().split("\n")
        root = lines.pop(0).strip()
        file_path_list = [
            os.path.join(root, line.split("\t")[0])
            for line in lines
            if len(line) > 0
        ]
        if sample_pct < 1.0:
            file_path_list = random.sample(
                file_path_list, int(sample_pct * len(file_path_list))
            )
        num_files = len(file_path_list)
        reader = feature_reader_cls(
            checkpoint_path=checkpoint_path, layer=layer
        )

        def iterate():
            for file_path in file_path_list:
                feats = reader.get_feats(file_path, channel_id=channel_id)
                yield feats.cpu().numpy()

    return iterate, num_files


def get_features(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten, channel_id=None
):
    generator, num_files = get_feature_iterator(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
        channel_id=channel_id
    )
    iterator = generator()

    features_list = []
    for features in tqdm.tqdm(iterator, total=num_files):
        features_list.append(features)

    # Explicit clean up
    del iterator
    del generator
    gc.collect()
    torch.cuda.empty_cache()

    if flatten:
        return np.concatenate(features_list)

    return features_list


def get_and_dump_features(
    feature_type,
    checkpoint_path,
    layer,
    manifest_path,
    sample_pct,
    flatten,
    out_features_path,
):
    # Feature extraction
    features_batch = get_features(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
        flatten=flatten,
    )

    # Save features
    out_dir_path = os.path.dirname(out_features_path)
    os.makedirs(out_dir_path, exist_ok=True)
    shutil.copyfile(
        manifest_path,
        os.path.join(out_dir_path, os.path.basename(manifest_path)),
    )
    np.save(out_features_path, features_batch)

    return features_batch
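A hedged usage sketch for `get_and_dump_features` above; the checkpoint, manifest, and output paths are placeholders, and the manifest is assumed to be in the fairseq audio-manifest format (root directory on the first line, then one tab-separated relative path per line, as the code above parses):

```python
# Sketch: dump HuBERT features for all files listed in a manifest.
# Every path and the layer index here are hypothetical placeholders.
from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_and_dump_features

features = get_and_dump_features(
    feature_type="hubert",
    checkpoint_path="hubert_base_ls960.pt",   # hypothetical acoustic model checkpoint
    layer=6,
    manifest_path="data/train.tsv",           # hypothetical manifest file
    sample_pct=0.1,                           # subsample 10% of files, e.g. for k-means fitting
    flatten=True,                             # concatenate into one (N, dim) array
    out_features_path="features/train_hubert_l6.npy",
)
```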
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import fairseq
import soundfile as sf


class Wav2VecFeatureReader:
    """
    Wrapper class to run inference on Wav2Vec 2.0 model.
    Helps extract features for a given audio file.
    """

    def __init__(self, checkpoint_path, layer, use_cuda=True):
        state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(
            checkpoint_path
        )

        w2v_args = state["args"]
        self.task = fairseq.tasks.setup_task(w2v_args)
        model = self.task.build_model(w2v_args)
        model.load_state_dict(state["model"], strict=True)
        model.eval()
        self.model = model
        self.layer = layer
        self.use_cuda = use_cuda
        if self.use_cuda:
            self.model.cuda()

    def read_audio(self, fname, channel_id=None):
        wav, sr = sf.read(fname)
        if channel_id is not None:
            assert wav.ndim == 2, \
                f"Expected stereo input when channel_id is given ({fname})"
            assert channel_id in [1, 2], \
                "channel_id is expected to be in [1, 2]"
            wav = wav[:, channel_id-1]
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        assert sr == self.task.cfg.sample_rate, sr
        return wav

    def get_feats(self, file_path, channel_id=None):
        x = self.read_audio(file_path, channel_id)
        with torch.no_grad():
            source = torch.from_numpy(x).view(1, -1).float()
            if self.use_cuda:
                source = source.cuda()
            res = self.model(
                source=source, mask=False, features_only=True, layer=self.layer
            )
            return res["layer_results"][self.layer][0].squeeze(1)
fairseq/examples/textless_nlp/gslm/tools/README.md
ADDED
@@ -0,0 +1,25 @@
# GSLM Tools

## Resynthesis
You can use the command-line tool below to feed in an audio file and get back the resynthesized audio. This tool implements the unsupervised resynthesis method described in the paper. Invoke it as follows:
```
FAIRSEQ_ROOT=<path_to_your_fairseq_repo_root>
TYPE=<one_of_logmel/cpc/hubert/w2v2>
ACOUSTIC_MODEL_PATH=<path_of_pretrained_acoustic_model>
LAYER=<layer_of_acoustic_model_to_extract_features_from>
KM_MODEL_PATH=<output_path_of_the_kmeans_model>
TTS_MODEL_PATH=<unit2speech_model_file_path>
# A text file containing the codes, one per line
CODE_DICT_PATH=<unit2speech_code_dict_path>
WAVEGLOW_PATH=<path_where_you_have_downloaded_waveglow_checkpoint>

PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/tools/resynthesize_speech.py \
    --feature_type $TYPE \
    --acoustic_model_path $ACOUSTIC_MODEL_PATH \
    --layer $LAYER \
    --kmeans_model_path $KM_MODEL_PATH \
    --tts_model_path $TTS_MODEL_PATH \
    --code_dict_path $CODE_DICT_PATH \
    --waveglow_path $WAVEGLOW_PATH \
    --max_decoder_steps 2000
```
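For orientation, a hedged sketch of the speech-to-unit step that the tool (added in the next file) performs per input: extract acoustic features, quantize them with the k-means model, and hand the resulting unit string to the unit2speech TTS. The checkpoint names and layer index here are placeholders:

```python
# Sketch of the quantization stage of resynthesis; all file names are hypothetical.
import joblib
from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_feature_reader

reader = get_feature_reader("hubert")(checkpoint_path="hubert_base_ls960.pt", layer=6)
kmeans = joblib.load("km100.bin")  # the trained k-means model (KM_MODEL_PATH)
feats = reader.get_feats("input.wav").cpu().numpy()
units = " ".join(map(str, kmeans.predict(feats)))  # e.g. "71 12 12 45 ..."
# `units` is then converted to a Tacotron input and vocoded with WaveGlow.
```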
fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py
ADDED
@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import gc
import logging
import os

import joblib
import soundfile as sf
import torch
from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_feature_reader
from examples.textless_nlp.gslm.unit2speech.tts_data import TacotronInputDataset
from examples.textless_nlp.gslm.unit2speech.utils import (
    load_tacotron,
    load_waveglow,
    synthesize_audio,
)


def get_logger():
    log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger


def get_parser():
    parser = argparse.ArgumentParser(description="GSLM U2S tool")
    parser.add_argument(
        "--feature_type",
        type=str,
        choices=["logmel", "hubert", "w2v2", "cpc"],
        default=None,
        required=True,
        help="Acoustic feature type",
    )
    parser.add_argument(
        "--acoustic_model_path",
        type=str,
        help="Pretrained acoustic model checkpoint",
    )
    parser.add_argument("--layer", type=int, help="Layer of acoustic model")
    parser.add_argument(
        "--kmeans_model_path",
        type=str,
        required=True,
        help="K-means model file path to use for inference",
    )
    parser.add_argument(
        "--tts_model_path",
        type=str,
        help="TTS model file path to use for inference",
    )
    parser.add_argument(
        "--code_dict_path",
        type=str,
        help="Code dict file path to use for inference",
    )
    parser.add_argument(
        "--waveglow_path",
        type=str,
        help="Waveglow (vocoder) model file path to use for inference",
    )
    parser.add_argument("--max_decoder_steps", type=int, default=2000)
    parser.add_argument("--denoiser_strength", type=float, default=0.1)
    return parser


################################################
def main(args, logger):
    # Acoustic Model
    logger.info(f"Loading acoustic model from {args.acoustic_model_path}...")
    feature_reader_cls = get_feature_reader(args.feature_type)
    reader = feature_reader_cls(
        checkpoint_path=args.acoustic_model_path, layer=args.layer
    )

    # K-means Model
    logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
    kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
    kmeans_model.verbose = False

    # TTS Model
    logger.info(f"Loading TTS model from {args.tts_model_path}...")
    tacotron_model, sample_rate, hparams = load_tacotron(
        tacotron_model_path=args.tts_model_path,
        max_decoder_steps=args.max_decoder_steps,
    )

    # Waveglow Model
    logger.info(f"Loading Waveglow model from {args.waveglow_path}...")
    waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path)

    # Dataset
    if not os.path.exists(hparams.code_dict):
        hparams.code_dict = args.code_dict_path
    tts_dataset = TacotronInputDataset(hparams)

    iters = 0
    while True:
        in_file_path = input("Input: Enter the full file path of audio file...\n")
        out_file_path = input("Output: Enter the full file path of audio file...\n")
        feats = reader.get_feats(in_file_path).cpu().numpy()
        iters += 1
        if iters == 1000:
            gc.collect()
            torch.cuda.empty_cache()

        quantized_units = kmeans_model.predict(feats)
        quantized_units_str = " ".join(map(str, quantized_units))

        tts_input = tts_dataset.get_tensor(quantized_units_str)
        mel, aud, aud_dn, has_eos = synthesize_audio(
            tacotron_model,
            waveglow,
            denoiser,
            tts_input.unsqueeze(0),
            strength=args.denoiser_strength,
        )
        sf.write(f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate)
        logger.info("Resynthesis done!\n")


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    logger = get_logger()
    logger.info(args)
    main(args, logger)
fairseq/examples/textless_nlp/gslm/ulm/README.md
ADDED
@@ -0,0 +1,72 @@
# Unit Language Model (ULM)

Here you can find links to the pre-trained ULMs and instructions on training new models using fairseq. At the end of the page, we also share how to run sampling for those models and provide pointers to the transcribed prompts we used.

## Pre-trained models

Using the links below, you can download pre-trained models for various unit types and vocabulary sizes:

| | 50 | 100 | 200
|-|-|-|-
| LogMel Filterbank | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km50/logmel50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km100/logmel100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km200/logmel200_lm.tgz)
| Modified CPC | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km50/cpc50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km100/cpc100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km200/cpc200_lm.tgz)
| HuBERT | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km50/hubert50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km100/hubert100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km200/hubert200_lm.tgz)
| Wav2Vec 2.0 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km50/w2v2_50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km100/w2v2_100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km200/w2v2_200_lm.tgz)


## Preprocessing data
Assuming that unit-transcribed train, valid, and test sets are located in `data/train.txt`, `data/valid.txt`, and `data/test.txt`, respectively,
we run the following command to get a preprocessed version of the dataset in `data-bin`:

```bash
fairseq-preprocess --only-source \
    --trainpref data/train.txt --validpref data/valid.txt --testpref data/test.txt \
    --destdir data-bin/ --workers 40
```
As a result, the `data-bin` directory should appear.

## Fitting a Unit Language Model (ULM)
As a ULM, we train a standard fairseq Transformer LM. Assuming 8 GPUs are used for training, a good starting point for ULM training would be:
```bash
fairseq-train data-bin/ \
    --task=language_modeling \
    --arch=transformer_lm_big \
    --share-decoder-input-output-embed \
    --dropout=0.1 \
    --attention-dropout=0.1 \
    --optimizer=adam \
    --adam-betas='(0.9, 0.98)' \
    --clip-norm=1.0 \
    --lr=0.0005 \
    --lr-scheduler=inverse_sqrt \
    --warmup-updates=4000 \
    --warmup-init-lr=1e-07 \
    --tokens-per-sample=3072 \
    --update-freq=16 \
    --max-tokens=4096 \
    --num-workers=4 \
    --skip-invalid-size-inputs-valid-test \
    --max-update=500000 \
    --log-interval=10 \
    --seed=100501 \
    --fp16 \
    --sample-break-mode=eos
```
This command will train a Transformer-large model (12 layers). You can train other standard LM models provided by fairseq, e.g. specify `--arch=transformer_lm` to train a smaller (6-layer) Transformer model. When training with a different number of GPUs, it might be a good idea to adjust the `update-freq` parameter. To save GPU memory at the expense of additional computation, it can be useful to enable activation checkpointing with `--checkpoint-activations`.

## Sampling from a ULM
Once a ULM has been trained, we can use it to generate new utterances. Suppose the prompts are given in a file named `prompts.txt` (see the sketch after this section for the expected line format). Then we can sample continuations by running the following command:

```bash
python sample.py data-bin/ \
    --path=checkpoints/checkpoint_best.pt --task=language_modeling --sampling --temperature=0.7 \
    --seed=1 --prompts=prompts.txt --output=samples.txt --max-len-a=0 --max-len-b=500 \
    --prefix-size=-1 --batch-size=16 --fp16 --samples-per-prompt=10
```
Here, `--prefix-size` controls the number of tokens used to prime the ULM. When set to a positive value, the sampling script takes the first `prefix-size` tokens of each prompt; with `0` it runs unconditional sampling, and with `-1` the entire prompt is used.
`--samples-per-prompt` specifies how many utterances are generated per prompt, which is useful when generating multiple continuations of the same prompt. In this command, `--max-len-a` and `--max-len-b` control the number of generated tokens.

When using a pretrained model from above, `data-bin` should point to the unpacked directory (the one containing the `dict.txt` file).

At evaluation time, to generate prompts, we used utterances from LibriSpeech dev-clean and test-clean that are longer than 6s, taking the first 3s of each utterance as a prompt. Unit transcripts of those prompts can be downloaded here: [[dev]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/dev_prompts.tgz) [[test]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/test_prompts.tgz)
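The sampling script (added in the next file) splits each prompt line on the first `|` into a sequence id and a unit transcript. A hedged sketch of writing such a file; the ids and unit values are made up for illustration:

```python
# Sketch: write prompts.txt in the "<seq_id>|<space-separated units>" format
# that sample.py below parses. All ids and units here are hypothetical.
prompts = {
    "dev-clean-1272-128104-0000": "71 12 12 45 3 3 99",
    "dev-clean-1272-128104-0001": "5 5 18 42 42 7",
}
with open("prompts.txt", "w") as f:
    for seq_id, units in prompts.items():
        f.write(f"{seq_id}|{units}\n")
```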
fairseq/examples/textless_nlp/gslm/ulm/sample.py
ADDED
@@ -0,0 +1,174 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Sample from a trained LM; hacked fairseq-interactive
"""
from collections import namedtuple
import os
import ast
import numpy as np

from fairseq import checkpoint_utils, options, tasks, utils

import tqdm

Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')


def make_batches(lines, args, task, max_positions):
    tokens = [
        task.source_dictionary.encode_line(
            src_str, add_if_not_exist=False
        ).long()
        for src_str in lines
    ]
    lengths = [t.numel() for t in tokens]
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(tokens, lengths),
        max_tokens=args.dataset.max_tokens,
        max_sentences=args.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=args.dataset.skip_invalid_size_inputs_valid_test
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            ids=batch['id'],
            src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'],
        )


def main(args):
    arg_prompts = args.prompts
    arg_output = args.output
    arg_debug = args.debug
    arg_sample_size = args.samples_per_prompt

    try:
        from fairseq.dataclass.utils import convert_namespace_to_omegaconf
        args = convert_namespace_to_omegaconf(args)
    except:
        pass

    # if args.max_tokens is None and args.max_sentences is None:
    if args.common.seed is not None:
        np.random.seed(args.common.seed)
        utils.set_torch_seed(args.common.seed)

    if args.generation.sampling:
        args.generation.nbest = args.generation.beam = arg_sample_size

    task = tasks.setup_task(args.task)

    overrides = ast.literal_eval(args.common_eval.model_overrides)

    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.common_eval.path.split(os.pathsep),
        arg_overrides=overrides,
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    output_file = open(arg_output, 'w')

    with open(arg_prompts, 'r') as fin:
        lines = fin.readlines()

    split = [x.split('|', 1) for x in lines]
    seq_id = [x[0] for x in split]
    prompts = [x[1] for x in split]

    if args.generation.prefix_size >= 0:
        prompts = [' '.join(l.split()[:args.generation.prefix_size])
                   for l in prompts]

    if arg_debug:
        prompts = prompts[:10]

    generator = task.build_generator(models, args.generation)

    start_id = 0
    pbar = tqdm.tqdm(total=len(prompts))
    for batch in make_batches(prompts, args, task, max_positions):
        src_tokens = batch.src_tokens
        src_lengths = batch.src_lengths
        src_tokens = src_tokens.cuda()
        src_lengths = src_lengths.cuda()

        sample = {
            'net_input': {
                'src_tokens': src_tokens,
                'src_lengths': src_lengths,
            },
        }

        results = []
        translations = task.inference_step(generator, models, sample)
        for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
            src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
            results.append((i + start_id, src_tokens_i, hypos))

        # sort output to match input order
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(
                    src_tokens, args.common_eval.post_process)

            # Process top predictions
            for hypo_id, hypo in enumerate(hypos):
                _hypo_tokens, hypo_str, _alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.common_eval.post_process,
                )

                detok_hypo_str = hypo_str
                utterance = detok_hypo_str
                print(f'{seq_id[id]}__{hypo_id}|{utterance}', file=output_file)
            pbar.update(1)
        start_id += len(results)

    # output_file.close()


def cli_main():
    parser = options.get_interactive_generation_parser()
    parser.add_argument('--prompts', type=str, default=None, required=True)
    parser.add_argument('--output', type=str, default=None, required=True)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--samples-per-prompt', type=int, default=1)

    args = options.parse_args_and_arch(parser)

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    main(args)


if __name__ == '__main__':
    cli_main()
fairseq/examples/textless_nlp/gslm/unit2speech/README.md
ADDED
@@ -0,0 +1,40 @@
# Unit to Speech Model (unit2speech)

The unit2speech model is a modified Tacotron2 model that learns to synthesize speech from discrete speech units. All models are trained on quantized [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).

Upstream Units | Download Links | model md5
|-|-|-
Log Mel Filterbank + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km50/code_dict) | 932b3b8527c0125f5f964b57762eba49
Log Mel Filterbank + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km100/code_dict) | cde0b0d278a39011d0acbd5df27abdf4
Log Mel Filterbank + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km200/code_dict) | dba0f1d4de64bc7976718834010b23e7
Modified CPC + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km50/code_dict) | a585e8dd8890ea56164f17635dd8e613
Modified CPC + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km100/code_dict) | 5c0ee2869b4f483d17f37f1a41a548e0
Modified CPC + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km200/code_dict) | 2f0c9951cf37020d9464514bff48bc5d
HuBERT Base + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km50/code_dict) | 85ffce8baec5aa90035ab696fe676fce
HuBERT Base + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km100/code_dict) | df4a9c6ffd1bb00c91405432c234aba3
HuBERT Base + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/code_dict) | ac72f2c0c563589819bec116c7f8d274
wav2vec 2.0 Large + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km50/code_dict) | e3503d0ad822b2c24b89f68b857fedff
wav2vec 2.0 Large + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km100/code_dict) | eb3666e456ae4c96bf2a1eec825c13ed
wav2vec 2.0 Large + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km200/code_dict) | 777d343e963c4d64f04d78eef032f4e8

## Run inference using a unit2speech model
* Install librosa, unidecode and inflect using `pip install librosa unidecode inflect`
* Download the [Waveglow checkpoint](https://dl.fbaipublicfiles.com/textless_nlp/gslm/waveglow_256channels_new.pt). This is the vocoder.

Sample command to run inference using trained unit2speech models. Please note that the quantized audio to be synthesized must use the same units as the unit2speech model was trained with.
```
FAIRSEQ_ROOT=<path_to_your_fairseq_repo_root>
TTS_MODEL_PATH=<unit2speech_model_file_path>
QUANTIZED_UNIT_PATH=<quantized_audio_file_path>
OUT_DIR=<dir_to_dump_synthesized_audio_files>
WAVEGLOW_PATH=<path_where_you_have_downloaded_waveglow_checkpoint>
CODE_DICT_PATH=<unit2speech_code_dict_path>

PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py \
    --tts_model_path $TTS_MODEL_PATH \
    --quantized_unit_path $QUANTIZED_UNIT_PATH \
    --out_audio_dir $OUT_DIR \
    --waveglow_path $WAVEGLOW_PATH \
    --code_dict_path $CODE_DICT_PATH \
    --max_decoder_steps 2000
```
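Since the table above lists md5 checksums for the TTS checkpoints, a small hedged sketch for verifying a download; the local file path is a placeholder and the expected hash is the HuBERT Base + KM100 entry from the table:

```python
# Sketch: verify a downloaded unit2speech checkpoint against its listed md5.
import hashlib

def md5sum(path, chunk_size=1 << 20):
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# "tts_checkpoint_best.pt" is a hypothetical local path for the HuBERT Base + KM100 model.
assert md5sum("tts_checkpoint_best.pt") == "df4a9c6ffd1bb00c91405432c234aba3"
```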
fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py
ADDED
@@ -0,0 +1,56 @@
import os
import shlex
import subprocess
import progressbar
from time import time
from pathlib import Path

def find_all_files(path_dir, extension):
    out = []
    for root, dirs, filenames in os.walk(path_dir):
        for f in filenames:
            if f.endswith(extension):
                out.append(((str(Path(f).stem)), os.path.join(root, f)))
    return out

def convert16k(inputfile, outputfile16k):
    command = ('sox -c 1 -b 16 {} -t wav {} rate 16k'.format(inputfile, outputfile16k))
    subprocess.call(shlex.split(command))

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Convert to wav 16k audio using sox.')
    parser.add_argument('input_dir', type=str,
                        help='Path to the input dir.')
    parser.add_argument('output_dir', type=str,
                        help='Path to the output dir.')
    parser.add_argument('--extension', type=str, default='wav',
                        help='Audio file extension in the input. Default: wav')
    args = parser.parse_args()

    # Find all sequences
    print(f"Finding all audio files with extension '{args.extension}' from {args.input_dir}...")
    audio_files = find_all_files(args.input_dir, args.extension)
    print(f"Done! Found {len(audio_files)} files.")

    # Convert to relative path
    audio_files = [os.path.relpath(file[-1], start=args.input_dir) for file in audio_files]

    # Create all the directories needed
    rel_dirs_set = set([os.path.dirname(file) for file in audio_files])
    for rel_dir in rel_dirs_set:
        Path(os.path.join(args.output_dir, rel_dir)).mkdir(parents=True, exist_ok=True)

    # Converting wavs files
    print("Converting the audio to wav files...")
    bar = progressbar.ProgressBar(maxval=len(audio_files))
    bar.start()
    start_time = time()
    for index, file in enumerate(audio_files):
        bar.update(index)
        input_file = os.path.join(args.input_dir, file)
        output_file = os.path.join(args.output_dir, os.path.splitext(file)[0]+".wav")
        convert16k(input_file, output_file)
    bar.finish()
    print(f"...done {len(audio_files)} files in {time()-start_time} seconds.")
fairseq/examples/textless_nlp/gslm/unit2speech/glow.py
ADDED
@@ -0,0 +1,312 @@
# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import copy
import torch
from torch.autograd import Variable
import torch.nn.functional as F


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a+input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class WaveGlowLoss(torch.nn.Module):
    def __init__(self, sigma=1.0):
        super(WaveGlowLoss, self).__init__()
        self.sigma = sigma

    def forward(self, model_output):
        z, log_s_list, log_det_W_list = model_output
        for i, log_s in enumerate(log_s_list):
            if i == 0:
                log_s_total = torch.sum(log_s)
                log_det_W_total = log_det_W_list[i]
            else:
                log_s_total = log_s_total + torch.sum(log_s)
                log_det_W_total += log_det_W_list[i]

        loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
        return loss/(z.size(0)*z.size(1)*z.size(2))


class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix.  If reverse=True it does convolution with
    inverse
    """
    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Sample a random orthonormal matrix to initialize weights
        _qr = torch.linalg.qr if torch.__version__ >= "1.8" else torch.qr
        W = _qr(torch.FloatTensor(c, c).normal_())[0]

        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:,0] = -1*W[:,0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W

    def forward(self, z, reverse=False):
        # shape
        batch_size, group_size, n_of_groups = z.size()

        W = self.conv.weight.squeeze()

        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Reverse computation
                W_inverse = W.float().inverse()
                W_inverse = Variable(W_inverse[..., None])
                if z.type() == 'torch.cuda.HalfTensor':
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            return z
        else:
            # Forward computation
            log_det_W = batch_size * n_of_groups * torch.logdet(W)
            z = self.conv(z)
            return z, log_det_W


class WN(torch.nn.Module):
    """
    This is the WaveNet like layer for the affine coupling.  The primary difference
    from WaveNet is the convolutions need not be causal.  There is also no dilation
    size reset.  The dilation only doubles on each layer
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first.  This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = 2 ** i
            padding = int((kernel_size*dilation - dilation)/2)
            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        audio, spect = forward_input
        audio = self.start(audio)
        output = torch.zeros_like(audio)
        n_channels_tensor = torch.IntTensor([self.n_channels])

        spect = self.cond_layer(spect)

        for i in range(self.n_layers):
            spect_offset = i*2*self.n_channels
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                spect[:,spect_offset:spect_offset+2*self.n_channels,:],
                n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = audio + res_skip_acts[:,:self.n_channels,:]
                output = output + res_skip_acts[:,self.n_channels:,:]
            else:
                output = output + res_skip_acts

        return self.end(output)


class WaveGlow(torch.nn.Module):
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow, self).__init__()

        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                 n_mel_channels,
                                                 1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()

        n_half = int(n_group/2)

        # Set up layers with the right sizes based on how many dimensions
        # have been output already
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size/2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
        self.n_remaining_channels = n_remaining_channels  # Useful during inference

    def forward(self, forward_input):
        """
        forward_input[0] = mel_spectrogram:  batch x n_mel_channels x frames
        forward_input[1] = audio: batch x time
        """
        spect, audio = forward_input

        #  Upsample spectrogram to size of audio
        spect = self.upsample(spect)
        assert(spect.size(2) >= audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, :audio.size(1)]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        output_audio = []
        log_s_list = []
        log_det_W_list = []

        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:,:self.n_early_size,:])
                audio = audio[:,self.n_early_size:,:]

            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s)*audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1],1)

        output_audio.append(audio)
        return torch.cat(output_audio,1), log_s_list, log_det_W_list

    def infer(self, spect, sigma=1.0):
        spect = self.upsample(spect)
        # trim conv artifacts. maybe pad spec to kernel multiple
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:, :, :-time_cutoff]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        if spect.type() == 'torch.cuda.HalfTensor':
            audio = torch.cuda.HalfTensor(spect.size(0),
                                          self.n_remaining_channels,
                                          spect.size(2)).normal_()
        else:
            audio = torch.cuda.FloatTensor(spect.size(0),
                                           self.n_remaining_channels,
                                           spect.size(2)).normal_()

        audio = torch.autograd.Variable(sigma*audio)

        for k in reversed(range(self.n_flows)):
            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))

            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b)/torch.exp(s)
            audio = torch.cat([audio_0, audio_1],1)

            audio = self.convinv[k](audio, reverse=True)

            if k % self.n_early_every == 0 and k > 0:
                if spect.type() == 'torch.cuda.HalfTensor':
                    z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                else:
                    z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                audio = torch.cat((sigma*z, audio),1)

        audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
        return audio

    @staticmethod
    def remove_weightnorm(model):
        waveglow = model
        for WN in waveglow.WN:
            WN.start = torch.nn.utils.remove_weight_norm(WN.start)
            WN.in_layers = remove(WN.in_layers)
            WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
            WN.res_skip_layers = remove(WN.res_skip_layers)
        return waveglow


def remove(conv_list):
    new_conv_list = torch.nn.ModuleList()
    for old_conv in conv_list:
        old_conv = torch.nn.utils.remove_weight_norm(old_conv)
        new_conv_list.append(old_conv)
    return new_conv_list
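A hedged inference sketch for the `WaveGlow` class in glow.py above: load a vocoder checkpoint, strip weight norm, and run `infer()` on a mel spectrogram. The checkpoint filename, its `{"model": ...}` layout, and the dummy tensor shapes are assumptions for illustration, not guarantees about the published checkpoint:

```python
# Sketch: vocode a mel spectrogram with WaveGlow. Assumes unit2speech/ is on
# PYTHONPATH (as in the README) so that `glow` resolves to the file above, and
# assumes the checkpoint stores the model under the "model" key.
import torch
from glow import WaveGlow  # glow.py above

ckpt = torch.load("waveglow_256channels_new.pt", map_location="cuda")  # assumed layout
waveglow = WaveGlow.remove_weightnorm(ckpt["model"]).cuda().eval()

mel = torch.randn(1, 80, 500).cuda()  # dummy (batch, n_mel_channels, frames) input
with torch.no_grad():
    audio = waveglow.infer(mel, sigma=0.666)  # (1, samples) waveform tensor
```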
fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py
ADDED
@@ -0,0 +1,27 @@
import os
import time
import torch
import sys
import subprocess

argslist = list(sys.argv)[1:]
log_dir = argslist[-1]
num_gpus = torch.cuda.device_count()
argslist.append('--n_gpus={}'.format(num_gpus))
workers = []
job_id = time.strftime("%Y_%m_%d-%H%M%S")
argslist.append("--group_name=group_{}".format(job_id))

print("GPU log directory is {}".format(log_dir))
os.makedirs(log_dir, exist_ok=True)
for i in range(num_gpus):
    argslist.append('--rank={}'.format(i))
    stdout = None if i == 0 else open("{}/{}_GPU_{}.log".format(log_dir, job_id, i),
                                      "w")
    print(argslist)
    p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout)
    workers.append(p)
    argslist = argslist[:-1]

for p in workers:
    p.wait()
fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py
ADDED
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os

import soundfile as sf
from examples.textless_nlp.gslm.unit2speech.tts_data import (
    TacotronInputDataset,
)
from examples.textless_nlp.gslm.unit2speech.utils import (
    load_quantized_audio_from_file,
    load_tacotron,
    load_waveglow,
    synthesize_audio,
)


def get_logger():
    log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger


def get_parser():
    parser = argparse.ArgumentParser(
        description="GSLM unit2speech generator."
    )
    parser.add_argument(
        "--quantized_unit_path",
        type=str,
        help="Quantized unit file path to use for inference",
    )
    parser.add_argument(
        "--tts_model_path",
        type=str,
        help="TTS model file path to use for inference",
    )
    parser.add_argument(
        "--waveglow_path",
        type=str,
        help="Path to the waveglow checkpoint (vocoder).",
    )
    parser.add_argument(
        "--code_dict_path",
        type=str,
        help="Code dict file path to use for inference",
    )
    parser.add_argument("--max_decoder_steps", type=int, default=2000)
    parser.add_argument("--denoiser_strength", type=float, default=0.1)
    parser.add_argument(
        "--out_audio_dir",
        type=str,
        help="Output directory to dump audio files",
    )

    return parser


def main(args, logger):
    # Load quantized audio
    logger.info(f"Loading quantized audio from {args.quantized_unit_path}...")
    names_batch, quantized_units_batch = load_quantized_audio_from_file(
        file_path=args.quantized_unit_path
    )

    logger.info(f"Loading TTS model from {args.tts_model_path}...")
    tacotron_model, sample_rate, hparams = load_tacotron(
        tacotron_model_path=args.tts_model_path,
        max_decoder_steps=args.max_decoder_steps,
    )

    logger.info(f"Loading Waveglow model from {args.waveglow_path}...")
    waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path)

    if not os.path.exists(hparams.code_dict):
        hparams.code_dict = args.code_dict_path
    tts_dataset = TacotronInputDataset(hparams)

    for name, quantized_units in zip(names_batch, quantized_units_batch):
        quantized_units_str = " ".join(map(str, quantized_units))
        tts_input = tts_dataset.get_tensor(quantized_units_str)
        mel, aud, aud_dn, has_eos = synthesize_audio(
            tacotron_model,
            waveglow,
            denoiser,
            tts_input.unsqueeze(0),
            strength=args.denoiser_strength,
        )
        out_file_path = os.path.join(args.out_audio_dir, f"{name}.wav")
        sf.write(
            f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate
        )


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    logger = get_logger()
    logger.info(args)
    main(args, logger)
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py
ADDED
File without changes
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py
ADDED
@@ -0,0 +1,93 @@
import torch
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util


def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function.  By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    win_sq = librosa_util.pad_center(win_sq, n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    PARAMS
    ------
    magnitudes: spectrogram magnitudes
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    """

    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    angles = angles.astype(np.float32)
    angles = torch.autograd.Variable(torch.from_numpy(angles))
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    for i in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py
ADDED
@@ -0,0 +1,90 @@
""" from https://github.com/keithito/tacotron """

'''
Cleaners are transformations that run over the input text at both training and eval time.

Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
  1. "english_cleaners" for English text
  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
     the symbols in symbols.py to match your data).
'''

import re
from unidecode import unidecode
from .numbers import normalize_numbers


# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
  ('mrs', 'misess'),
  ('mr', 'mister'),
  ('dr', 'doctor'),
  ('st', 'saint'),
  ('co', 'company'),
  ('jr', 'junior'),
  ('maj', 'major'),
  ('gen', 'general'),
  ('drs', 'doctors'),
  ('rev', 'reverend'),
  ('lt', 'lieutenant'),
  ('hon', 'honorable'),
  ('sgt', 'sergeant'),
  ('capt', 'captain'),
  ('esq', 'esquire'),
  ('ltd', 'limited'),
  ('col', 'colonel'),
  ('ft', 'fort'),
]]


def expand_abbreviations(text):
  for regex, replacement in _abbreviations:
    text = re.sub(regex, replacement, text)
  return text


def expand_numbers(text):
  return normalize_numbers(text)


def lowercase(text):
  return text.lower()


def collapse_whitespace(text):
  return re.sub(_whitespace_re, ' ', text)


def convert_to_ascii(text):
  return unidecode(text)


def basic_cleaners(text):
  '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
  text = lowercase(text)
  text = collapse_whitespace(text)
  return text


def transliteration_cleaners(text):
  '''Pipeline for non-English text that transliterates to ASCII.'''
  text = convert_to_ascii(text)
  text = lowercase(text)
  text = collapse_whitespace(text)
  return text


def english_cleaners(text):
  '''Pipeline for English text, including number and abbreviation expansion.'''
  text = convert_to_ascii(text)
  text = lowercase(text)
  text = expand_numbers(text)
  text = expand_abbreviations(text)
  text = collapse_whitespace(text)
  return text
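As a quick sanity check of the English pipeline above, a minimal sketch (the import path assumes the fairseq repository root is on PYTHONPATH; the sample sentence is ours, not from the repository):

```python
from examples.textless_nlp.gslm.unit2speech.tacotron2.cleaners import english_cleaners

# Converts to ASCII, lowercases, spells out numbers, and expands abbreviations,
# e.g. "Dr." -> "doctor" and "$3.50" -> "three dollars, fifty cents".
print(english_cleaners("Dr. Smith paid $3.50 at the new office."))
```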
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py
ADDED
@@ -0,0 +1,65 @@
""" from https://github.com/keithito/tacotron """

import re


valid_symbols = [
  'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
  'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
  'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
  'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
  'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
  'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
  'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)


class CMUDict:
  '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
  def __init__(self, file_or_path, keep_ambiguous=True):
    if isinstance(file_or_path, str):
      with open(file_or_path, encoding='latin-1') as f:
        entries = _parse_cmudict(f)
    else:
      entries = _parse_cmudict(file_or_path)
    if not keep_ambiguous:
      entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
    self._entries = entries


  def __len__(self):
    return len(self._entries)


  def lookup(self, word):
    '''Returns list of ARPAbet pronunciations of the given word.'''
    return self._entries.get(word.upper())



_alt_re = re.compile(r'\([0-9]+\)')


def _parse_cmudict(file):
  cmudict = {}
  for line in file:
    if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
      # CMUdict entries separate the word from its pronunciation with two spaces.
      parts = line.split('  ')
      word = re.sub(_alt_re, '', parts[0])
      pronunciation = _get_pronunciation(parts[1])
      if pronunciation:
        if word in cmudict:
          cmudict[word].append(pronunciation)
        else:
          cmudict[word] = [pronunciation]
  return cmudict


def _get_pronunciation(s):
  parts = s.strip().split(' ')
  for part in parts:
    if part not in _valid_symbol_set:
      return None
  return ' '.join(parts)
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py
ADDED
@@ -0,0 +1,103 @@
import torch
from librosa.filters import mel as librosa_mel_fn
from .audio_processing import dynamic_range_compression
from .audio_processing import dynamic_range_decompression
from .stft import STFT
from .utils import get_mask_from_lengths


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class GlobalAvgPool(torch.nn.Module):
    def __init__(self):
        super(GlobalAvgPool, self).__init__()

    def forward(self, x, lengths=None):
        """Average pooling across time steps (dim=1) with optionally lengths.
        Args:
            x: torch.Tensor of shape (N, T, ...)
            lengths: None or torch.Tensor of shape (N,)
            dim: dimension to pool
        """
        if lengths is None:
            return x.mean(dim=1, keepdim=False)
        else:
            mask = get_mask_from_lengths(lengths).type(x.type()).to(x.device)
            mask_shape = list(mask.size()) + [1 for _ in range(x.ndimension()-2)]
            mask = mask.reshape(*mask_shape)
            numer = (x * mask).sum(dim=1, keepdim=False)
            denom = mask.sum(dim=1, keepdim=False)
            return numer / denom


class TacotronSTFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output
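A minimal sketch of driving TacotronSTFT above, assuming an older librosa (pre-0.10) whose mel filterbank accepts the positional arguments used in the constructor; the waveform is random placeholder data:

```python
import torch

taco_stft = TacotronSTFT()            # defaults: 1024-point FFT, hop 256, 80 mel bins
wav = torch.rand(2, 22050) * 2 - 1    # (B, T) batch of fake 1-second clips in [-1, 1]
mel = taco_stft.mel_spectrogram(wav)  # (2, 80, n_frames), log-compressed magnitudes
print(mel.shape)
```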
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py
ADDED
@@ -0,0 +1,669 @@
from math import sqrt
import torch
import torch.distributions as distr
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from .layers import ConvNorm, LinearNorm, GlobalAvgPool
from .utils import to_gpu, get_mask_from_lengths


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolution with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
        )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x


class Encoder(nn.Module):
    """Encoder module:
        - Three 1-d convolution banks
        - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        # pytorch tensor are not reversible, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs


class AudioEncoder(nn.Module):
    def __init__(self, hparams):
        super(AudioEncoder, self).__init__()

        assert hparams.lat_dim > 0

        convolutions = []
        inp_dim = hparams.n_mel_channels
        for _ in range(hparams.lat_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(inp_dim, hparams.lat_n_filters,
                         kernel_size=hparams.lat_kernel_size, stride=1,
                         padding=int((hparams.lat_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.lat_n_filters))
            inp_dim = hparams.lat_n_filters
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.lat_n_filters,
                            int(hparams.lat_n_filters / 2),
                            hparams.lat_n_blstms, batch_first=True,
                            bidirectional=True)
        self.pool = GlobalAvgPool()

        self.mu_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
        self.logvar_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
        self.lat_dim = hparams.lat_dim

    def forward(self, x, lengths):
        """
        Args:
            x (torch.Tensor): (B, F, T)
        """

        for conv in self.convolutions:
            x = F.dropout(F.tanh(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)  # (B, T, D)

        # x may not be sorted by length. Sort->process->unsort
        max_len = x.size(1)
        assert max_len == torch.max(lengths).item()

        lengths, perm_idx = lengths.sort(0, descending=True)
        x = x[perm_idx]
        x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        _, unperm_idx = perm_idx.sort(0)
        outputs = outputs[unperm_idx]  # (B, T, D)
        lengths = lengths[unperm_idx]  # (B, T, D)

        outputs = self.pool(outputs, lengths)  # (B, D)

        mu = self.mu_proj(outputs)
        logvar = self.logvar_proj(outputs)
        z = distr.Normal(mu, logvar).rsample()
        return z, mu, logvar


class Decoder(nn.Module):
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.obs_dim = hparams.obs_dim
        self.lat_dim = hparams.lat_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        encoder_tot_dim = (hparams.encoder_embedding_dim + \
                hparams.lat_dim + hparams.obs_dim)
        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + encoder_tot_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + encoder_tot_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + encoder_tot_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: decoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, obs_and_lat, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.obs_and_lat = obs_and_lat
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs

        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:

        RETURNS
        -------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output

        RETURNS
        -------
        mel_output:
        gate_output: gate output energies
        attention_weights:
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        if self.obs_and_lat is not None:
            decoder_input = torch.cat((decoder_input, self.obs_and_lat), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        if self.obs_and_lat is not None:
            decoder_hidden_attention_context = torch.cat(
                (decoder_hidden_attention_context, self.obs_and_lat), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, obs_and_lat, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, obs_and_lat, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory, obs_and_lat, ret_has_eos=False):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, obs_and_lat, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        has_eos = False
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                has_eos = True
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                # print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        if ret_has_eos:
            return mel_outputs, gate_outputs, alignments, has_eos
        else:
            return mel_outputs, gate_outputs, alignments


class Tacotron2(nn.Module):
    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step

        # initialize text encoder embedding
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)

        # initialize observed attribute embedding
        self.obs_embedding = None
        if hparams.obs_dim > 0:
            self.obs_embedding = nn.Embedding(
                hparams.obs_n_class, hparams.obs_dim)
            std = sqrt(2.0 / (hparams.obs_n_class + hparams.obs_dim))
            val = sqrt(3.0) * std  # uniform bounds for std
            self.obs_embedding.weight.data.uniform_(-val, val)

        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

        self.lat_encoder = None
        if hparams.lat_dim > 0:
            self.lat_encoder = AudioEncoder(hparams)

    def parse_batch(self, batch):
        (text_padded, input_lengths, obs_labels,
         mel_padded, gate_padded, output_lengths) = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        obs_labels = to_gpu(obs_labels).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (text_padded, input_lengths, obs_labels,
             mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs

    def forward(self, inputs):
        (text_inputs, text_lengths, obs_labels,
         mels, max_len, output_lengths) = inputs
        text_lengths, output_lengths = text_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

        obs = None
        if self.obs_embedding is not None:
            obs = self.obs_embedding(obs_labels)

        lat, lat_mu, lat_logvar = None, None, None
        if self.lat_encoder is not None:
            (lat, lat_mu, lat_logvar) = self.lat_encoder(mels, output_lengths)

        obs_and_lat = [x for x in [obs, lat] if x is not None]
        if bool(obs_and_lat):
            obs_and_lat = torch.cat(obs_and_lat, dim=-1)
        else:
            obs_and_lat = None

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, obs_and_lat, mels, memory_lengths=text_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments,
             lat_mu, lat_logvar],
            output_lengths)

    def inference(self, inputs, obs_labels=None, lat=None, ret_has_eos=False):
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)

        if obs_labels is None:
            obs_labels = torch.LongTensor(len(inputs))
            obs_labels = obs_labels.to(inputs.device).zero_()

        obs = None
        if self.obs_embedding is not None:
            obs = self.obs_embedding(obs_labels)

        if self.lat_encoder is not None:
            if lat is None:
                lat = torch.FloatTensor(len(inputs), self.lat_encoder.lat_dim)
                lat = lat.to(inputs.device).zero_().type(encoder_outputs.type())

        obs_and_lat = [x for x in [obs, lat] if x is not None]
        if bool(obs_and_lat):
            obs_and_lat = torch.cat(obs_and_lat, dim=-1)
        else:
            obs_and_lat = None

        mel_outputs, gate_outputs, alignments, has_eos = self.decoder.inference(
            encoder_outputs, obs_and_lat, ret_has_eos=True)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        if ret_has_eos:
            return outputs + [has_eos]
        else:
            return outputs
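A minimal inference sketch for the Tacotron2 module above, assuming `model` is an already-loaded instance (e.g. restored as in synthesize_audio_from_units.py earlier in this diff) and `seq` is a (1, T) LongTensor of symbol IDs:

```python
import torch

model.eval()
with torch.no_grad():
    mel, mel_postnet, gate, alignments, has_eos = model.inference(
        seq, ret_has_eos=True)
# mel_postnet: (1, n_mel_channels, T_out); this is what gets passed to the
# WaveGlow vocoder to produce a waveform.
```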
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py
ADDED
@@ -0,0 +1,71 @@
""" from https://github.com/keithito/tacotron """

import inflect
import re


_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')


def _remove_commas(m):
  return m.group(1).replace(',', '')


def _expand_decimal_point(m):
  return m.group(1).replace('.', ' point ')


def _expand_dollars(m):
  match = m.group(1)
  parts = match.split('.')
  if len(parts) > 2:
    return match + ' dollars'  # Unexpected format
  dollars = int(parts[0]) if parts[0] else 0
  cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
  if dollars and cents:
    dollar_unit = 'dollar' if dollars == 1 else 'dollars'
    cent_unit = 'cent' if cents == 1 else 'cents'
    return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
  elif dollars:
    dollar_unit = 'dollar' if dollars == 1 else 'dollars'
    return '%s %s' % (dollars, dollar_unit)
  elif cents:
    cent_unit = 'cent' if cents == 1 else 'cents'
    return '%s %s' % (cents, cent_unit)
  else:
    return 'zero dollars'


def _expand_ordinal(m):
  return _inflect.number_to_words(m.group(0))


def _expand_number(m):
  num = int(m.group(0))
  if num > 1000 and num < 3000:
    if num == 2000:
      return 'two thousand'
    elif num > 2000 and num < 2010:
      return 'two thousand ' + _inflect.number_to_words(num % 100)
    elif num % 100 == 0:
      return _inflect.number_to_words(num // 100) + ' hundred'
    else:
      return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
  else:
    return _inflect.number_to_words(num, andword='')


def normalize_numbers(text):
  text = re.sub(_comma_number_re, _remove_commas, text)
  text = re.sub(_pounds_re, r'\1 pounds', text)
  text = re.sub(_dollars_re, _expand_dollars, text)
  text = re.sub(_decimal_number_re, _expand_decimal_point, text)
  text = re.sub(_ordinal_re, _expand_ordinal, text)
  text = re.sub(_number_re, _expand_number, text)
  return text
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py
ADDED
@@ -0,0 +1,141 @@
"""
BSD 3-Clause License

Copyright (c) 2017, Prem Seetharaman
All rights reserved.

* Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny
from .audio_processing import window_sumsquare


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False)
            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
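A minimal round-trip sketch for the STFT module above, with random audio standing in for a real waveform (shapes follow the default 800-sample filter and 200-sample hop):

```python
import torch

stft = STFT(filter_length=800, hop_length=200, win_length=800, window='hann')
signal = torch.rand(1, 16000) * 2 - 1            # (B, T) in [-1, 1]
magnitude, phase = stft.transform(signal)        # each roughly (B, 401, n_frames)
reconstruction = stft.inverse(magnitude, phase)  # approximately the input signal
```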
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py
ADDED
@@ -0,0 +1,18 @@
""" from https://github.com/keithito/tacotron """

'''
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
from . import cmudict

_pad = '_'
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in cmudict.valid_symbols]

# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py
ADDED
@@ -0,0 +1,107 @@
""" from https://github.com/keithito/tacotron """
import numpy as np
import re
from . import cleaners
from .symbols import symbols


# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')

# Special symbols
SOS_TOK = '<s>'
EOS_TOK = '</s>'


def text_to_sequence(text, cleaner_names):
  '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
      text: string to convert to a sequence
      cleaner_names: names of the cleaner functions to run the text through

    Returns:
      List of integers corresponding to the symbols in the text
  '''
  sequence = []

  # Check for curly braces and treat their contents as ARPAbet:
  while len(text):
    m = _curly_re.match(text)
    if not m:
      sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
      break
    sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
    sequence += _arpabet_to_sequence(m.group(2))
    text = m.group(3)

  return sequence


def sample_code_chunk(code, size):
  assert(size > 0 and size <= len(code))
  start = np.random.randint(len(code) - size + 1)
  end = start + size
  return code[start:end], start, end


def code_to_sequence(code, code_dict, collapse_code):
  if collapse_code:
    prev_c = None
    sequence = []
    for c in code:
      if c in code_dict and c != prev_c:
        sequence.append(code_dict[c])
        prev_c = c
  else:
    sequence = [code_dict[c] for c in code if c in code_dict]
    if len(sequence) < 0.95 * len(code):
      print('WARNING : over 5% of the codes are OOV')

  return sequence


def sequence_to_text(sequence):
  '''Converts a sequence of IDs back to a string'''
  result = ''
  for symbol_id in sequence:
    if symbol_id in _id_to_symbol:
      s = _id_to_symbol[symbol_id]
      # Enclose ARPAbet back in curly braces:
      if len(s) > 1 and s[0] == '@':
        s = '{%s}' % s[1:]
      result += s
  return result.replace('}{', ' ')


def sequence_to_code(sequence, code_dict):
  '''Analogous to sequence_to_text'''
  id_to_code = {i: c for c, i in code_dict.items()}
  return ' '.join([id_to_code[i] for i in sequence])


def _clean_text(text, cleaner_names):
  for name in cleaner_names:
    cleaner = getattr(cleaners, name)
    if not cleaner:
      raise Exception('Unknown cleaner: %s' % name)
    text = cleaner(text)
  return text


def _symbols_to_sequence(symbols):
  return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
  return _symbols_to_sequence(['@' + s for s in text.split()])


def _should_keep_symbol(s):
  return s in _symbol_to_id and s != '_' and s != '~'
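For orientation, a minimal usage sketch of the two public entry points above; the example sentence and the printed comments are illustrative and not taken from the repository:

```python
# Minimal usage sketch for text_to_sequence / sequence_to_text (illustrative input).
from examples.textless_nlp.gslm.unit2speech.tacotron2.text import (
    sequence_to_text,
    text_to_sequence,
)

# ARPAbet inside curly braces bypasses the cleaners and maps to the "@"-prefixed symbols.
ids = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
print(ids)                    # a list of integer symbol IDs
print(sequence_to_text(ids))  # text restored, with the ARPAbet span re-wrapped in braces
```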
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py
ADDED
@@ -0,0 +1,171 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import collections
import io
import json
import librosa
import numpy as np
import soundfile as sf
import time
import torch
from scipy.io.wavfile import read
from .text import SOS_TOK, EOS_TOK


def get_mask_from_lengths(lengths):
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    mask = (ids < lengths.unsqueeze(1))
    return mask


def load_wav_to_torch(full_path, sr=None):
    data, sr = librosa.load(full_path, sr=sr)
    data = np.clip(data, -1, 1)  # potentially out of [-1, 1] due to resampling
    data = data * 32768.0  # match values loaded by scipy
    return torch.FloatTensor(data.astype(np.float32)), sr


def read_binary_audio(bin_data, tar_sr=None):
    """
    read binary audio (`bytes` or `uint8` `numpy.ndarray`) to a `float32`
    `torch.FloatTensor`

    RETURNS:
        data (torch.FloatTensor) : audio of shape (n,) or (2, n)
        tar_sr (int) : sample rate
    """
    data, ori_sr = sf.read(io.BytesIO(bin_data), dtype='float32')
    data = data.T
    if (tar_sr is not None) and (ori_sr != tar_sr):
        data = librosa.resample(data, ori_sr, tar_sr)
    else:
        tar_sr = ori_sr
    data = np.clip(data, -1, 1)
    data = data * 32768.0
    return torch.FloatTensor(data.astype(np.float32)), tar_sr


def load_filepaths_and_text(filename):
    with open(filename, encoding='utf-8') as f:
        data = [json.loads(line.rstrip()) for line in f]
    return data


def to_gpu(x):
    x = x.contiguous()

    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    return torch.autograd.Variable(x)


def load_code_dict(path, add_sos=False, add_eos=False):
    if not path:
        return {}

    with open(path, 'r') as f:
        codes = ['_'] + [line.rstrip() for line in f]  # '_' for pad
    code_dict = {c: i for i, c in enumerate(codes)}

    if add_sos:
        code_dict[SOS_TOK] = len(code_dict)
    if add_eos:
        code_dict[EOS_TOK] = len(code_dict)
    assert(set(code_dict.values()) == set(range(len(code_dict))))

    return code_dict


def load_obs_label_dict(path):
    if not path:
        return {}
    with open(path, 'r') as f:
        obs_labels = [line.rstrip() for line in f]
    return {c: i for i, c in enumerate(obs_labels)}


# A simple timer class inspired from `tnt.TimeMeter`
class CudaTimer:
    def __init__(self, keys):
        self.keys = keys
        self.reset()

    def start(self, key):
        s = torch.cuda.Event(enable_timing=True)
        s.record()
        self.start_events[key].append(s)
        return self

    def stop(self, key):
        e = torch.cuda.Event(enable_timing=True)
        e.record()
        self.end_events[key].append(e)
        return self

    def reset(self):
        self.start_events = collections.defaultdict(list)
        self.end_events = collections.defaultdict(list)
        self.running_times = collections.defaultdict(float)
        self.n = collections.defaultdict(int)
        return self

    def value(self):
        self._synchronize()
        return {k: self.running_times[k] / self.n[k] for k in self.keys}

    def _synchronize(self):
        torch.cuda.synchronize()
        for k in self.keys:
            starts = self.start_events[k]
            ends = self.end_events[k]
            if len(starts) == 0:
                raise ValueError("Trying to divide by zero in TimeMeter")
            if len(ends) != len(starts):
                raise ValueError("Call stop before checking value!")
            time = 0
            for start, end in zip(starts, ends):
                time += start.elapsed_time(end)
            self.running_times[k] += time * 1e-3
            self.n[k] += len(starts)
        self.start_events = collections.defaultdict(list)
        self.end_events = collections.defaultdict(list)


# Used to measure the time taken for multiple events
class Timer:
    def __init__(self, keys):
        self.keys = keys
        self.n = {}
        self.running_time = {}
        self.total_time = {}
        self.reset()

    def start(self, key):
        self.running_time[key] = time.time()
        return self

    def stop(self, key):
        self.total_time[key] = time.time() - self.running_time[key]
        self.n[key] += 1
        self.running_time[key] = None
        return self

    def reset(self):
        for k in self.keys:
            self.total_time[k] = 0
            self.running_time[k] = None
            self.n[k] = 0
        return self

    def value(self):
        vals = {}
        for k in self.keys:
            if self.n[k] == 0:
                raise ValueError("Trying to divide by zero in TimeMeter")
            else:
                vals[k] = self.total_time[k] / self.n[k]
        return vals
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py
ADDED
@@ -0,0 +1,40 @@
# import sys
# sys.path.append('tacotron2')
import torch
from .layers import STFT


class Denoiser(torch.nn.Module):
    """ Removes model bias from audio produced with waveglow """

    def __init__(self, waveglow, filter_length=1024, n_overlap=4,
                 win_length=1024, mode='zeros'):
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length/n_overlap),
                         win_length=win_length).cuda()
        if mode == 'zeros':
            mel_input = torch.zeros(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        elif mode == 'normal':
            mel_input = torch.randn(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        else:
            raise Exception("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised
fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py
ADDED
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
import numpy as np
from examples.textless_nlp.gslm.unit2speech.tacotron2.text import (
    EOS_TOK,
    SOS_TOK,
    code_to_sequence,
    text_to_sequence,
)
from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import (
    load_code_dict,
)


class TacotronInputDataset:
    def __init__(self, hparams, append_str=""):
        self.is_text = getattr(hparams, "text_or_code", "text") == "text"
        if not self.is_text:
            self.code_dict = load_code_dict(
                hparams.code_dict, hparams.add_sos, hparams.add_eos
            )
            self.code_key = hparams.code_key
        self.add_sos = hparams.add_sos
        self.add_eos = hparams.add_eos
        self.collapse_code = hparams.collapse_code
        self.append_str = append_str

    def process_code(self, inp_str):
        inp_toks = inp_str.split()
        if self.add_sos:
            inp_toks = [SOS_TOK] + inp_toks
        if self.add_eos:
            inp_toks = inp_toks + [EOS_TOK]
        return code_to_sequence(inp_toks, self.code_dict, self.collapse_code)

    def process_text(self, inp_str):
        return text_to_sequence(inp_str, ["english_cleaners"])

    def get_tensor(self, inp_str):
        # uid, txt, inp_str = self._get_data(idx)
        inp_str = inp_str + self.append_str
        if self.is_text:
            inp_toks = self.process_text(inp_str)
        else:
            inp_toks = self.process_code(inp_str)
        return torch.from_numpy(np.array(inp_toks)).long()

    def __len__(self):
        # NOTE: `self.data` is never assigned in __init__ above.
        return len(self.data)
fairseq/examples/textless_nlp/gslm/unit2speech/utils.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
from examples.textless_nlp.gslm.unit2speech.tacotron2.model import Tacotron2
from examples.textless_nlp.gslm.unit2speech.tacotron2.waveglow_denoiser import (
    Denoiser,
)


def load_quantized_audio_from_file(file_path):
    base_fname_batch, quantized_units_batch = [], []
    with open(file_path) as f:
        for line in f:
            base_fname, quantized_units_str = line.rstrip().split("|")
            quantized_units = [int(q) for q in quantized_units_str.split(" ")]
            base_fname_batch.append(base_fname)
            quantized_units_batch.append(quantized_units)
    return base_fname_batch, quantized_units_batch


def synthesize_audio(model, waveglow, denoiser, inp, lab=None, strength=0.0):
    assert inp.size(0) == 1
    inp = inp.cuda()
    if lab is not None:
        lab = torch.LongTensor(1).cuda().fill_(lab)

    with torch.no_grad():
        _, mel, _, ali, has_eos = model.inference(inp, lab, ret_has_eos=True)
        aud = waveglow.infer(mel, sigma=0.666)
        aud_dn = denoiser(aud, strength=strength).squeeze(1)
    return mel, aud, aud_dn, has_eos


def load_tacotron(tacotron_model_path, max_decoder_steps):
    ckpt_dict = torch.load(tacotron_model_path)
    hparams = ckpt_dict["hparams"]
    hparams.max_decoder_steps = max_decoder_steps
    sr = hparams.sampling_rate
    model = Tacotron2(hparams)
    model.load_state_dict(ckpt_dict["model_dict"])
    model = model.cuda().eval().half()
    return model, sr, hparams


def load_waveglow(waveglow_path):
    waveglow = torch.load(waveglow_path)["model"]
    waveglow = waveglow.cuda().eval().half()
    for k in waveglow.convinv:
        k.float()
    denoiser = Denoiser(waveglow)
    return waveglow, denoiser
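Taken together, the helpers above form the unit-to-waveform path used by the synthesis scripts in this folder. A minimal sketch of how they compose; the checkpoint paths, the unit string and the output file name are placeholders, and `TacotronInputDataset` is the class from `tts_data.py` above:

```python
# Minimal sketch of the unit-to-waveform path (paths and the unit string are placeholders).
import soundfile as sf

from examples.textless_nlp.gslm.unit2speech.tts_data import TacotronInputDataset
from examples.textless_nlp.gslm.unit2speech.utils import (
    load_tacotron,
    load_waveglow,
    synthesize_audio,
)

tacotron, sample_rate, hparams = load_tacotron("tts_checkpoint.pt", max_decoder_steps=2000)
waveglow, denoiser = load_waveglow("waveglow.pt")

dataset = TacotronInputDataset(hparams)        # hparams decide between text and unit-code input
inp = dataset.get_tensor("63 63 12 12 99 7")   # quantized units as a space-separated string
mel, aud, aud_dn, has_eos = synthesize_audio(
    tacotron, waveglow, denoiser, inp.unsqueeze(0)
)
sf.write("sample.wav", aud_dn[0].cpu().float().numpy(), sample_rate)
```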
fairseq/examples/textless_nlp/pgslm/README.md
ADDED
@@ -0,0 +1,318 @@
# Text-Free Prosody-Aware Generative Spoken Language Modeling

This folder contains code and recipes to reproduce results reported in the paper _Text-Free Prosody-Aware Generative Spoken Language Modeling_,
Eugene Kharitonov*, Ann Lee*, Adam Polyak, Yossi Adi, Jade Copet, Kushal Lakhotia, Tu-Anh Nguyen, Morgane Rivière, Abdelrahman Mohamed, Emmanuel Dupoux, Wei-Ning Hsu, 2021. arxiv/2109.03264 [[arxiv]](https://arxiv.org/abs/2109.03264).

`*` denotes equal contribution.

You can find demo samples [[here]](https://speechbot.github.io/pgslm/index.html).

<details>
<summary>If you find this code useful, please consider citing our work using this bibtex entry</summary>

```
@misc{Kharitonov2021,
      title={Text-Free Prosody-Aware Generative Spoken Language Modeling},
      author={Eugene Kharitonov and Ann Lee and Adam Polyak and Yossi Adi and Jade Copet and Kushal Lakhotia and Tu-Anh Nguyen and Morgane Rivière and Abdelrahman Mohamed and Emmanuel Dupoux and Wei-Ning Hsu},
      year={2021},
      eprint={2109.03264},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
</details>


## Additional requirements
The following packages are required in addition to fairseq; they are installable with pip:
```bash
pip install AMFM-decompy SoundFile scipy sklearn torchaudio npy-append-array
```

## Data preprocessing

### Prepare unit pseudo-text transcriptions of the audio
To get unit transcripts of the speech data, we rely on the preprocessing steps of the [GSLM](https://github.com/pytorch/fairseq/tree/main/examples/textless_nlp/gslm/speech2unit/) work.

First, we need to prepare manifest files for the dataset we want to preprocess:
```
mkdir manifests/
python examples/wav2vec/wav2vec_manifest.py --valid-percent=0.0 $DATA_PATH --dest=manifests/train/
```
Next, we need a pre-trained HuBERT-base-ls960 model [[download]](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) and a corresponding kmeans-100 quantizer [[download]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km100/km.bin). Having those, we can quantize the dataset:
```
python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \
    --feature_type hubert \
    --kmeans_model_path km.bin \
    --acoustic_model_path hubert_base_ls960.pt \
    --layer 6 \
    --manifest_path manifests/train/train.tsv \
    --out_quantized_file_path manifests/train/units
```

Finally, by running
```
python examples/textless_nlp/pgslm/scripts/join_units_manifest.py --manifest=manifests/train/train.tsv --units=manifests/train/units --output=train.txt
```
we get the training data description `train.txt` in the format that pGSLM expects. The above steps have to be repeated for the
dev/test sets. Importantly, we rely on the assumption that the directories are structured as in LibriSpeech, i.e. the file paths follow the
`<spk_id>/<session_id>/<sample_id>.wav` format.

### Preprocess data for pGSLM
The very first step is to obtain the F0 quantization bins.
Assume the vocoder training manifest is `vocoder_train.txt` (in the pGSLM data format prepared with the same process as above).
We prepare the quantized F0 from the vocoder training data by running
```sh
bash examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh \
  vocoder_train.txt <sample_rate> 32 <preprocessed_dir> <output_prefix> # we use 32 bins in the paper
```
- `<sample_rate>`: sampling rate of the audio files in the manifest
- `<preprocessed_dir>`: where to write the output files
- `<output_prefix>`: prefix of the output files

The script will generate
- `<output_prefix>.f0_stat.pt`: the speaker-level F0 statistics, which can be used in vocoder training
- `<output_prefix>_mean_norm_log_f0_bin.th`: the quantized F0 bins, which should be used in `prepare_data.sh` below

**Note:** See "Pre-trained models" for the pre-computed speaker-level F0 statistics and quantized F0 bins. We suggest using the pre-computed statistics for the data preparation below in order to take advantage of the pre-trained vocoder for waveform generation.

Next, prepare the pGSLM data.
Assume the train/valid/test manifests are `{train,valid,test}.txt`.
Here is an example of how to preprocess the data:

```sh
bash examples/textless_nlp/pgslm/scripts/prepare_data.sh \
  train.txt valid.txt test.txt <n_unit> <hop_size> <sample_rate> \
  <preprocessed_dir>/<output_prefix>_mean_norm_log_f0_bin.th <preprocessed_dir>
```
- `<n_unit>`: discrete unit vocabulary size (we used a kmeans quantizer with 100 units in the example above)
- `<hop_size>`: downsampling rate relative to the waveform (e.g., 320 for HuBERT units)
- `<sample_rate>`: sampling rate of the audio files in the manifest
- `<preprocessed_dir>`: where to output the preprocessed files

This will create the dataset json config used for the next section at
`<preprocessed_dir>/data_config.json`.

Note that the example script uses only one thread to compute F0, which can take
_a very long time_ when preprocessing large datasets. It is suggested to distribute
jobs over multiple nodes/processes with `--nshards=x` and `--rank=z` (where z is
in [1, x]) in `preprocess_f0.py`, and to set `--nshards_list=x` in
`prepare_data.py` correspondingly to collect the sharded F0 data.

Now, everything is ready for training a model.

## Training the Multi-Stream Transformer Unit Language Model (MS-TLM)

Below is an example command that trains a Multi-Stream Transformer Language Model (MS-TLM) on a prepared dataset:
```bash
DATASET=data_config.json

fairseq-train $DATASET \
    --task=speech_unit_modeling \
    --arch="transformer_ulm_tiny" \
    --criterion=speech_unit_lm_criterion \
    --share-decoder-input-output-embed \
    --dropout=0.1 \
    --attention-dropout=0.1 \
    --optimizer="adam" \
    --adam-betas="(0.9, 0.98)" \
    --clip-norm=1.0 \
    --lr=0.0005 \
    --lr-scheduler="inverse_sqrt" \
    --warmup-updates=4000 \
    --warmup-init-lr=1e-07 \
    --tokens-per-sample=3072 \
    --max-tokens=3072 \
    --update-freq=4 \
    --max-epoch=70 \
    --num-workers=0 \
    --skip-invalid-size-inputs-valid-test \
    --loss-weights="1.0;0.5;0.0" \
    --ignore-f0-input \
    --checkpoint-activations \
    --fp16 \
    --max-target-positions=4096 \
    --stream-shifts="1,1" \
    --log-f0 --normalize-f0-mean --interpolate-f0 \
    --ignore-unused-valid-subsets \
    --discrete-duration --discrete-f0
```

Some of the important parameters that are specific to MS-TLM:
* `arch`: specifies the Transformer architecture used. Supported options are:
  * `transformer_ulm_tiny` - a tiny model that can be used for debugging; it has 2 layers, 1 attention head, and FFN and embedding dimensions of 64,
  * `transformer_ulm` - a base model with 6 layers, 8 heads, embedding dimension 512, and FFN dimensionality of 2048,
  * `transformer_ulm_big` - the largest model we experiment with in the paper: 12 layers/16 heads, 1024/4096 embedding and FFN dimensions;
* `loss-weights`: this parameter sets importance weights (must be non-negative) for the components of the loss that correspond to the unit, duration, and F0 streams. To turn off a component of the loss, its weight has to be set to 0. For instance, to predict only the unit stream the parameter should be set to "1;0;0";
* `stream-shifts`: specifies relative shifts of the two prosodic streams w.r.t. the unit stream (duration and F0, respectively). No shift corresponds to "0,0" (see the sketch after this list);
* `ignore-duration-input`/`ignore-f0-input`: setting these flags zeroes out the corresponding input streams;
* `max-token-duration`: duration values are capped at the specified value;
* `discrete-duration`/`discrete-f0`: whether the duration and F0 streams should be quantized;
* `log_f0`, `normalize-f0-mean`, `normalize-f0-std`, `interpolate-f0`: configure how the F0 stream is treated. `log_f0` sets up modelling in log-space, `normalize-f0-mean`/`normalize-f0-std` control per-speaker normalization, and `interpolate-f0` enables F0 interpolation for unvoiced regions where F0 is set to 0;
* `mask-dur-prob`, `mask-f0-prob`, `mask-dur-seg-prob`, `mask-f0-seg-prob`, `mask-unit-seg-prob`, `mask-unit-seg-leng`: this family of parameters sets the probabilities of masking individual steps and spans on each stream, as well as the lengths of the masked spans.

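To make `stream-shifts` concrete, here is a small illustrative sketch of what `--stream-shifts="1,1"` means for the model inputs. This is not the fairseq data-loading code, and the unit/duration/F0 values are made up; the idea is that the prosody channels are delayed relative to the unit channel, so the prosody of a unit is predicted after the unit itself has been seen.

```python
# Illustrative sketch of stream-shifts="1,1" (not the actual fairseq dataloader; values are made up).
import torch

units = torch.tensor([17, 42, 42, 5])        # unit stream
durs = torch.tensor([3, 2, 4, 1])            # duration stream, aligned with the units
f0 = torch.tensor([0.1, 0.3, 0.2, 0.4])      # (normalized) F0 stream, aligned with the units

shift = 1  # --stream-shifts="1,1" delays both prosody channels by one step
durs_shifted = torch.cat([durs.new_zeros(shift), durs[:-shift]])
f0_shifted = torch.cat([f0.new_zeros(shift), f0[:-shift]])

# Model input at step t is the triple (units[t], durs_shifted[t], f0_shifted[t]).
print(list(zip(units.tolist(), durs_shifted.tolist(), f0_shifted.tolist())))
```
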
## Pre-trained models
### MS-TLM
Below you can find checkpoints for the four best-performing models from the paper (IDs 9..12 in Table 1). These models are trained on HuBERT-100 transcripts of the LibriLight-6K dataset. They have the prosody streams shifted by 1 w.r.t. the unit stream. All models predict all three streams (units, duration, and F0), but two
of them only have the unit stream in their input.

| | Continuous prosody | Quantized prosody |
|-------------------|--------------------|-------------------|
| No prosody input | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/continuous_no_prosody_shift_1_1.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/discrete_no_prosody_shift_1_1.pt) |
| Has prosody input | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/continuous_prosody_shift_1_1.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/discrete_prosody_shift_1_1.pt)|

The optimal per-stream sampling temperatures/scaling parameters that we have identified for these models, in the (`T-token, T-duration, T-f0`) format:

| | Continuous prosody | Quantized prosody |
|-------------------|--------------------|-------------------|
| No prosody input | 0.7, 0.125, 0.0003125| 0.7, 0.25, 0.5 |
| Has prosody input | 0.7, 0.125, 0.00125 | 0.7, 0.25, 0.7 |

## Vocoder
| Units | Prosody | F0 stats | Checkpoint | Config |
|-------------------|---------|--------------|------------|--------|
| HuBERT-base-ls960, kmeans-100 | [[Quantized 32 bins]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_seg_bin.th) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/f0_stats.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/naive_quant_32_norm_log_seg_hubert/checkpoint.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/naive_quant_32_norm_log_seg_hubert/config.json) |
| HuBERT-base-ls960, kmeans-100 | Continuous | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/f0_stats.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_hubert/checkpoint.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_hubert/config.json) |


## Evaluating a trained model
Evaluation is done with the `eval/cont_metrics.py` script. As described in the paper, several metrics are used.

**Teacher-forced metrics**
```bash
SET=valid
CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt
DATA=data_config.json

python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \
    --metric=teacher_force_everything \
    --path=$CHECKPOINT_PATH \
    --batch-size=16 \
    --fp16 \
    --seed=111 \
    --eval-subset=$SET \
    --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody
```
(Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{'token_loss': 1.408..., 'duration_loss': 0.5424..., 'f0_loss': 0.0474...}` on LibriSpeech dev-clean).

The parameters `--f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody` are specific to quantized-prosody models. They signal that the prosody streams must be decoded into the continuous domain before the metrics are computed. It is the same `*_mean_norm_log_f0_bin.th` file as we prepared before.
The `mean_norm_log_f0_seg_bin.th` file we used with the pre-trained models can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_seg_bin.th).


**Consistency (aka Correlation) metrics**

The following command estimates the correlation between the mean values of the F0 stream in the prompt and in the generated continuation (the unit and duration streams are fixed); a minimal sketch of the correlation computation follows the parameter notes below.

```bash
T_F0=0.7
EXPLOSION=20
SET=test
CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt
DATA=data_config.json

python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \
    --prefix-length=150 \
    --metric=correlation \
    --path=$CHECKPOINT_PATH \
    --batch-size=16 \
    --fp16 \
    --seed=111 \
    --teacher-force-tokens \
    --teacher-force-duration \
    --min-length=300 \
    --batch-explosion-rate=$EXPLOSION \
    --T-f0=$T_F0 \
    --eval-subset=$SET \
    --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th \
    --dequantize-prosody --n-workers=8
```
(Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{...'F0 corr': 0.315 ..}` on LibriSpeech test-clean).

* By using the flags `--teacher-force-tokens, --teacher-force-duration, --teacher-force-f0` one can calculate correlations along each stream while keeping the other two streams fixed to ground-truth values (or freeze all three streams to get ground-truth correlation values);
* The parameters `T-f0`, `T-duration`, and `T-token` specify per-stream temperatures and, in the case of continuous-valued prosody, the scaling parameter of the corresponding Laplace distribution (setting a temperature to 0 enforces greedy sampling);
* `min-length` filters out sequences that are shorter than 300 duration units (i.e., 6 s in the case of HuBERT units);
* `prefix-length` specifies that we use the first 150 duration units as the prompt (i.e., 3 s in the case of HuBERT units).

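As referenced above, a minimal sketch of the consistency computation, assuming the per-utterance mean F0 of the prompt and of the sampled continuation have already been collected. The arrays below are made-up placeholders, and using Pearson correlation via `scipy.stats.pearsonr` is an assumption about the exact statistic; `cont_metrics.py` gathers these means internally.

```python
# Minimal sketch of the prompt/continuation consistency metric (placeholder data).
import numpy as np
from scipy.stats import pearsonr

# One entry per evaluated utterance: mean F0 over the prompt and over the generated continuation.
prompt_mean_f0 = np.array([118.2, 201.5, 95.7, 150.3, 182.9])
continuation_mean_f0 = np.array([121.0, 188.4, 101.2, 160.9, 175.3])

corr, _p_value = pearsonr(prompt_mean_f0, continuation_mean_f0)
print(f"F0 corr: {corr:.3f}")
```
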
**Correctness (aka Continuation) and Expressiveness (aka Std) metrics**

By running the following command, we can get minMAE and Std for the log-F0 stream for the model with quantized prosody.
```bash
DATA=data_config.json
EXPLOSION=20
SET=test
CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt
T_F0=0.7

python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \
    --prefix-length=150 \
    --metric=continuation \
    --path=$CHECKPOINT_PATH \
    --batch-size=16 \
    --fp16 \
    --seed=111 \
    --batch-explosion-rate=$EXPLOSION \
    --teacher-force-tokens \
    --teacher-force-duration \
    --T-f0=$T_F0 \
    --eval-subset=$SET \
    --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody
```
(Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{...'F0 MAE': 0.0772, 'F0 Std': 0.1489...}` on LibriSpeech test-clean).

Again, by setting `--teacher-force-tokens, --teacher-force-duration, --teacher-force-f0` we can calculate Token BLEU for the token stream (when `--teacher-force-duration` & `--teacher-force-f0` are on) and the per-stream min MAE for each prosody stream individually; a short sketch of the min-over-samples aggregation follows.

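A short illustrative sketch of the aggregation behind minMAE and Std, assuming `batch-explosion-rate` continuations were sampled for one prompt. The shapes and tensors are made up, and the real `continuation()` routine in `cont_metrics.py` additionally normalizes by token counts across the whole evaluation set.

```python
# Illustrative minMAE / Std aggregation for one prompt (made-up tensors).
import torch

K, T = 20, 120                            # batch-explosion-rate samples, continuation length
sampled_f0 = 0.2 * torch.rand(K, T)       # K sampled F0 continuations (e.g., normalized log-F0)
reference_f0 = 0.2 * torch.rand(T)        # ground-truth continuation
valid = torch.ones(T, dtype=torch.bool)   # False would mark padded positions to ignore

per_sample_mae = ((sampled_f0 - reference_f0).abs() * valid).sum(dim=1) / valid.sum()
min_mae = per_sample_mae.min()            # "min" in minMAE: best of the K samples (Correctness)
std = sampled_f0[:, valid].std()          # spread of the generated values (Expressiveness)
print(min_mae.item(), std.item())
```
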
Finally, `cont_metrics.py` allows specifying the number of workers (e.g., `n-workers=8`), which speeds up the computation by spreading multiple worker processes
over the available GPUs.

**Cont Word BLEU**

We used the code and the evaluation protocol of [(Lakhotia et al., 2021)](https://arxiv.org/abs/2102.01192).

## Sampling from a trained model

To get samples (prompted or not) from a trained model, it is enough to run `sample.py`:
```bash
CHECKPOINT_PATH=checkpoints/checkpoint_best.pt
DATASET=examples/textless_nlp/pgslm/repro/dataset/data_config.json
python examples/textless_nlp/pgslm/sample/sample.py $DATASET \
    --output=$SAMPLES \
    --path=$CHECKPOINT_PATH \
    --sampling \
    --T-token=0.7 \
    --T-duration=0.25 \
    --T-f0=0.7 \
    --max-length=500 \
    --prefix-length=150 \
    --subset=valid \
    --seed=1 \
    --match-duration \
    --code-type=hubert \
    --batch-explosion-rate=2
```

Some useful parameters:
* `T-token`, `T-duration`, `T-f0` specify the sampling temperatures for the three streams. Setting a temperature to `0` switches sampling to greedy (argmax) decoding;
* `prefix-length`: length of the prompt, measured in timesteps (e.g., for HuBERT (CPC) each timestep is 20 (10) ms);
* `subset`: which subset of the dataset to use as prompts (can be `train`, `valid`, `test`);
* `teacher-force-tokens`, `teacher-force-duration`, `teacher-force-f0`: if set, at each autoregressive step, ground-truth values replace the produced ones;
* `short-curcuit`: replace sampling by ground-truth inputs;
* `match-duration`: forces the produced sample to have the same duration (in time) as the entire sequence (beyond the prompt, if there is any);
* `batch-explosion-rate`: number of samples per prompt;
* `f0-discretization-bounds`: path to a file with quantization boundaries. If it is set, F0 values are de-quantized back to the continuous domain
(the model must be a quantized one);
* `max-length` sets the maximal number of segment steps to be produced.

Note that `sample.py` automatically uses all available GPUs; to restrict this, set the `CUDA_VISIBLE_DEVICES` environment variable.

## Vocoding samples
To generate audio for the output of `sample.py` (`$IN_FILE`):
```bash
python examples/textless_nlp/pgslm/generate_waveform.py \
    --in-file=$IN_FILE \
    --vocoder=$VOCODER \
    --vocoder-cfg=$VOCODER_CFG \
    --results-path=$RESULTS_PATH
```
See "Pre-trained models" for `$VOCODER` and `$VOCODER_CFG`.
fairseq/examples/textless_nlp/pgslm/data_utils.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch

from tqdm import tqdm


class Stat:
    def __init__(self, keep_raw=False):
        self.x = 0.0
        self.x2 = 0.0
        self.z = 0.0  # z = logx
        self.z2 = 0.0
        self.n = 0.0
        self.u = 0.0
        self.keep_raw = keep_raw
        self.raw = []

    def update(self, new_x):
        new_z = new_x.log()

        self.x += new_x.sum()
        self.x2 += (new_x**2).sum()
        self.z += new_z.sum()
        self.z2 += (new_z**2).sum()
        self.n += len(new_x)
        self.u += 1

        if self.keep_raw:
            self.raw.append(new_x)

    @property
    def mean(self):
        return self.x / self.n

    @property
    def std(self):
        return (self.x2 / self.n - self.mean**2) ** 0.5

    @property
    def mean_log(self):
        return self.z / self.n

    @property
    def std_log(self):
        return (self.z2 / self.n - self.mean_log**2) ** 0.5

    @property
    def n_frms(self):
        return self.n

    @property
    def n_utts(self):
        return self.u

    @property
    def raw_data(self):
        assert self.keep_raw, "does not support storing raw data!"
        return torch.cat(self.raw)


class F0Stat(Stat):
    def update(self, new_x):
        # assume unvoiced frames are 0 and consider only voiced frames
        if new_x is not None:
            super().update(new_x[new_x != 0])


def dump_speaker_f0_stat(speaker_to_f0_stat, out_prefix):
    path = f"{out_prefix}.f0_stat.pt"
    assert not os.path.exists(path)

    d = {
        speaker: {
            "f0_mean": speaker_to_f0_stat[speaker].mean,
            "f0_std": speaker_to_f0_stat[speaker].std,
            "logf0_mean": speaker_to_f0_stat[speaker].mean_log,
            "logf0_std": speaker_to_f0_stat[speaker].std_log,
        }
        for speaker in speaker_to_f0_stat
    }
    torch.save(d, path)

    return d


def load_audio_path(path):
    audio_paths = []
    with open(path) as f:
        for line in f.readlines():
            sample = eval(line.strip())
            audio_paths.append(sample["audio"])

    return audio_paths


def load_f0(f0_dir, nshards):
    path_to_f0 = {}
    for rank in tqdm(range(1, nshards + 1), desc=f"load f0"):
        f0_shard_path = f"{f0_dir}/f0_{rank}_{nshards}.pt"
        shard_path_to_f0 = torch.load(f0_shard_path)
        path_to_f0.update(shard_path_to_f0)
    return path_to_f0
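A small sketch of how `F0Stat` and `dump_speaker_f0_stat` are meant to be used together when accumulating per-speaker statistics; the speaker IDs and F0 tensors are made up, and in the real pipeline the tensors come from the F0 extraction step:

```python
# Illustrative accumulation of per-speaker F0 statistics (made-up data).
from collections import defaultdict

import torch

from examples.textless_nlp.pgslm.data_utils import F0Stat, dump_speaker_f0_stat

speaker_to_f0_stat = defaultdict(F0Stat)
# Each tensor is the F0 track of one utterance; zeros mark unvoiced frames and are ignored.
speaker_to_f0_stat["spk1"].update(torch.tensor([0.0, 110.0, 112.0, 0.0, 121.0]))
speaker_to_f0_stat["spk1"].update(torch.tensor([108.0, 0.0, 115.0]))
speaker_to_f0_stat["spk2"].update(torch.tensor([205.0, 0.0, 198.0, 210.0]))

stats = dump_speaker_f0_stat(speaker_to_f0_stat, out_prefix="example")  # writes example.f0_stat.pt
print(stats["spk1"]["f0_mean"], stats["spk1"]["logf0_std"])
```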
fairseq/examples/textless_nlp/pgslm/eval/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py
ADDED
@@ -0,0 +1,730 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import scipy

import torch
import torch.multiprocessing as mp
from fairseq import checkpoint_utils, options
from fairseq.data.codedataset import CodeDataset, ExpressiveCodeDataConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from torch.utils.data import DataLoader, DistributedSampler
from fairseq.utils import move_to_cuda
from fairseq import utils
from fairseq.criterions.speech_ulm_criterion import nll_loss, mae_loss

import time
from types import SimpleNamespace

import sys, pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent.resolve()))

from naive_decoder import Naive_F0_Decoder
from inference_dataset import InferenceDataset, explode_batch
from sample.sample import do_sampling, TemperatureDecoder, FilterNamesDataset

try:
    from nltk.translate.bleu_score import sentence_bleu
except ImportError:
    print("Please install nltk: `pip install --user -U nltk`")
    raise


@torch.no_grad()
def teacher_force_everything(
    args, dataset, model, criterion, tgt_dict, rank, world_size
):
    prefix = args.prefix_length

    f0_decoder = None
    if args.dequantize_prosody:
        assert dataset.discrete_f0
        print("Reporting MAE for a discrete model")
        f0_decoder = Naive_F0_Decoder(
            args.f0_discretization_bounds, dataset.config.f0_vq_n_units
        ).cuda()

    dataset = InferenceDataset(
        dataset,
        prefix=args.prefix_length,
        only_prefix=False,
        filter_short=True,
        presort_by_length=True,
    )
    sampler = (
        None
        if world_size == 1
        else DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
    )
    dataloader = DataLoader(
        dataset,
        args.batch_size,
        shuffle=False,
        collate_fn=dataset.collater,
        sampler=sampler,
    )

    total_token_loss, total_duration_loss, total_f0_loss, total_tokens = (
        0.0,
        0.0,
        0.0,
        0.0,
    )

    i = 0
    for batch in dataloader:
        i += 1
        batch = move_to_cuda(batch)
        output = model(**batch["net_input"])

        tokens, durations, f0 = output["token"], output["duration"], output["f0"]
        durations, f0 = durations.squeeze(), f0.squeeze()

        token_loss = nll_loss(
            tokens[:, prefix - 1 :],
            batch["target"][:, prefix - 1 :].contiguous(),
            batch["mask"][:, prefix - 1 :].contiguous(),
            reduce=True,
        )

        if args.dequantize_prosody:
            durations = durations.argmax(dim=-1)
            duration_loss = mae_loss(
                durations[:, prefix - 1 :].contiguous().float(),
                batch["dur_target"][:, prefix - 1 :].contiguous().float(),
                batch["dur_mask"][:, prefix - 1 :].contiguous(),
                reduce=True,
            )
        else:
            duration_loss = criterion.dur_loss_fn(
                durations[:, prefix - 1 :].contiguous(),
                batch["dur_target"][:, prefix - 1 :].contiguous(),
                batch["dur_mask"][:, prefix - 1 :].contiguous(),
                reduce=True,
            )

        if f0_decoder:
            f0 = f0.argmax(dim=-1)
            f0 = f0_decoder(f0).squeeze(-1)

            f0_target = batch["raw_f0"]
            f0_loss = mae_loss(
                f0[:, prefix - 1 :].contiguous(),
                f0_target[:, prefix - 1 :].contiguous(),
                batch["f0_mask"][:, prefix - 1 :].contiguous(),
                reduce=True,
            )
        else:
            f0_loss = criterion.f0_loss_fn(
                f0[:, prefix - 1 :].contiguous(),
                batch["f0_target"][:, prefix - 1 :].contiguous(),
                batch["f0_mask"][:, prefix - 1 :].contiguous(),
                reduce=True,
            )

        n_tokens = (~batch["dur_mask"])[:, prefix - 1 :].sum()

        total_token_loss += token_loss.item()
        total_duration_loss += duration_loss.item()
        total_f0_loss += f0_loss.item()

        total_tokens += n_tokens.item()
        if args.debug and i > 5:
            break

    values = torch.tensor([total_token_loss, total_duration_loss, total_f0_loss])
    normalizers = torch.tensor([total_tokens for _ in range(3)])

    return values, normalizers


def get_bleu(produced_tokens, target_tokens, tgt_dict):
    assert target_tokens.ndim == 1
    assert produced_tokens.size(1) == target_tokens.size(0)

    # we can have padding due to shifted channels
    shift = 0
    for token in reversed(target_tokens.cpu().tolist()):
        if token in [tgt_dict.pad(), tgt_dict.eos()]:
            shift += 1
        else:
            break
    target_tokens = target_tokens[:-shift]
    produced_tokens = produced_tokens[:, :-shift]

    string_target = tgt_dict.string(target_tokens).split()
    string_candidates = [
        tgt_dict.string(produced_tokens[i, :]).split()
        for i in range(produced_tokens.size(0))
    ]

    bleu3 = sentence_bleu(
        references=string_candidates,
        hypothesis=string_target,
        weights=(1.0 / 3, 1.0 / 3, 1.0 / 3),
    )
    return bleu3


@torch.no_grad()
def continuation(args, dataset, model, criterion, tgt_dict, rank, world_size):
    is_discrete_duration = dataset.discrete_dur
    is_discrete_f0 = dataset.discrete_f0

    f0_decoder = None
    if args.dequantize_prosody:
        assert dataset.discrete_f0
        print("Reporting MAE F0 for a discrete model")
        f0_decoder = Naive_F0_Decoder(
            args.f0_discretization_bounds, dataset.config.f0_vq_n_units
        ).cuda()

    dataset = InferenceDataset(
        dataset, args.prefix_length, filter_short=True, presort_by_length=True
    )
    sampler = (
        None
        if world_size == 1
        else DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collater,
        sampler=sampler,
    )

    Ts = args.T_token, args.T_duration, args.T_f0
    decoder = TemperatureDecoder(
        Ts, discrete_dur=is_discrete_duration, discrete_f0=is_discrete_f0
    )

    running_stats = SimpleNamespace(
        token_bleu=0.0,
        duration_nll=0.0,
        duration_mae=0.0,
        f0_nll=0.0,
        f0_mae=0.0,
        n_tokens=0.0,
        n_sentences=0.0,
        f0_sum=0.0,
        f0_sum_sq=0.0,
        dur_sum=0.0,
        dur_sum_sq=0.0,
    )

    for i, batch in enumerate(dataloader):
        batch = explode_batch(batch, args.batch_explosion_rate)
        bsz = batch["target"].size(0)

        batch = move_to_cuda(batch)
        prefix = batch["prefix"][0]

        max_length_to_unroll = batch["target"].size(1)
        prefix_length = batch["net_input"]["src_tokens"].size(1)
        steps = max_length_to_unroll - prefix_length + 1

        assert steps > 0
        produced_tokens, produced_durations, produced_f0, outputs = do_sampling(
            model,
            batch,
            tgt_dict.eos(),
            decoder,
            autoregressive_steps=steps,
            teacher_force_tokens=args.teacher_force_tokens,
            teacher_force_duration=args.teacher_force_duration,
            teacher_force_f0=args.teacher_force_f0,
        )

        if args.teacher_force_tokens:
            assert (produced_tokens[:, 1:] == batch["target"]).all()
        if args.teacher_force_duration:
            assert (produced_durations[:, 1:] == batch["dur_target"]).all()
        if args.teacher_force_f0:
            assert (produced_f0[:, 1:] == batch["f0_target"]).all()

        dur_target = batch["dur_target"][:, prefix - 1 :].contiguous()
        f0_target = batch["f0_target"][:, prefix - 1 :].contiguous()

        f0_mask = batch["f0_mask"][:, prefix - 1 :].contiguous()
        dur_mask = batch["dur_mask"][:, prefix - 1 :].contiguous()

        duration_mae = mae_loss(
            produced_durations[:, prefix:].float(),
            dur_target.float(),
            dur_mask,
            reduce=False,
        )
        min_duration_mae = duration_mae.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
        running_stats.duration_mae += min_duration_mae

        running_stats.dur_sum += (
            produced_durations[:, prefix:].float() * (~dur_mask)
        ).sum() / args.batch_explosion_rate
        running_stats.dur_sum_sq += (
            produced_durations[:, prefix:].float() * (~dur_mask)
        ).pow(2.0).sum() / args.batch_explosion_rate

        if is_discrete_duration:
            duration_loss = criterion.dur_loss_fn(
                torch.stack([x[1] for x in outputs], dim=1),
                dur_target,
                dur_mask,
                reduce=False,
            )
            min_duration_loss = duration_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
            running_stats.duration_nll += min_duration_loss

        if f0_decoder:  # can only exist for discrete F0 models
            decoded_produced_f0 = f0_decoder(produced_f0[:, prefix:])
            decoded_f0_target = batch["raw_f0"][:, prefix - 1 :].contiguous()

            if produced_f0.ndim == 3:
                decoded_produced_f0 = decoded_produced_f0.squeeze(2)
                decoded_f0_target = decoded_f0_target.squeeze(2)

            f0_mae = mae_loss(
                decoded_produced_f0, decoded_f0_target, f0_mask, reduce=False
            )
            f0_mae = f0_mae.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
            running_stats.f0_mae += f0_mae

            f0_loss = criterion.f0_loss_fn(
                torch.stack([x[2] for x in outputs], dim=1),
                f0_target.long(),
                f0_mask,
                reduce=False,
            )
            f0_loss = f0_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
            running_stats.f0_nll += f0_loss

            running_stats.f0_sum += (
                decoded_produced_f0 * (~f0_mask)
            ).sum() / args.batch_explosion_rate
            running_stats.f0_sum_sq += (decoded_produced_f0 * (~f0_mask)).pow(
                2.0
            ).sum() / args.batch_explosion_rate

        else:
            assert not is_discrete_duration

            f0_loss = mae_loss(
                produced_f0[:, prefix:], f0_target, f0_mask, reduce=False
            )
            f0_loss = f0_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0]
            running_stats.f0_mae += f0_loss

            running_stats.f0_sum += (
                produced_f0[:, prefix:].sum() / args.batch_explosion_rate
            )
            running_stats.f0_sum_sq += (
                produced_f0[:, prefix:].pow(2.0).sum() / args.batch_explosion_rate
            )

        running_stats.n_tokens += (~dur_mask)[0, ...].sum()

        token_loss = get_bleu(
            produced_tokens[:, prefix:], batch["target"][0, prefix - 1 :], tgt_dict
        )
        running_stats.token_bleu += token_loss
        running_stats.n_sentences += 1

        if args.debug:
            break

    values = torch.tensor(
        [
            running_stats.token_bleu,
            running_stats.duration_nll,
            running_stats.duration_mae,
            running_stats.f0_nll,
            running_stats.f0_mae,
            running_stats.f0_sum,
            running_stats.f0_sum_sq,
            running_stats.dur_sum,
            running_stats.dur_sum_sq,
        ]
    )
    normalizers = torch.tensor(
        [running_stats.n_sentences] + [running_stats.n_tokens] * 8
    )

    return values, normalizers


@torch.no_grad()
def correlation(args, dataset, model, criterion, tgt_dict, rank, world_size):
    is_discrete_duration = dataset.discrete_dur
    is_discrete_f0 = dataset.discrete_f0

    f0_decoder = None
    if is_discrete_f0:
        assert dataset.discrete_f0
        f0_decoder = Naive_F0_Decoder(
            args.f0_discretization_bounds, dataset.config.f0_vq_n_units
        ).cuda()

    if is_discrete_f0:
        assert f0_decoder  # correlation on tokens is meaningless

    dataset = InferenceDataset(
        dataset,
        args.prefix_length,
        filter_short=True,
        presort_by_length=True,
        min_length=args.min_length,
    )
    sampler = (
        None
        if world_size == 1
        else DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collater,
        sampler=sampler,
    )

    Ts = args.T_token, args.T_duration, args.T_f0
    decoder = TemperatureDecoder(
        Ts, discrete_dur=is_discrete_duration, discrete_f0=is_discrete_f0
    )

    mean_dur_prefix, mean_dur_cont = [], []
    mean_f0_prefix, mean_f0_cont = [], []

    for batch in dataloader:
        batch = explode_batch(batch, args.batch_explosion_rate)
        batch = move_to_cuda(batch)

        assert len(batch["prefix"]) == 1

        if args.teacher_force_tokens:
            autoregressive_steps = batch["target"].size(1) - args.prefix_length - 1
        else:
            autoregressive_steps = args.max_length - args.prefix_length  # + max_shift?

        if args.copy_target:
            produced_durations, produced_f0 = batch["dur_target"], batch["f0_target"]
        else:
            _, produced_durations, produced_f0, outputs = do_sampling(
                model,
                batch,
                tgt_dict.eos(),
                decoder,
                autoregressive_steps=autoregressive_steps,
                teacher_force_tokens=args.teacher_force_tokens,
                teacher_force_duration=args.teacher_force_duration,
                teacher_force_f0=args.teacher_force_f0,
            )

            # first tokens actually correspond to BOS
            produced_durations = produced_durations[:, 1:]
            produced_f0 = produced_f0[:, 1:]

        dur_target = batch["dur_target"]
        if is_discrete_duration:
            produced_durations = produced_durations.float()
            dur_target = dur_target.float()

        if is_discrete_f0:
            produced_f0 = f0_decoder(produced_f0).squeeze(-1)
            f0_target = batch["raw_f0"]
        else:
            f0_target = batch["f0_target"]

        # prefix values
        prefix = batch["prefix"][0]
        dur_prefix_mean = dur_target[:, :prefix].sum(dim=-1) / (
            (~batch["dur_mask"][:, :prefix]).sum(dim=-1)
        )

        non_voiced = f0_target[:, :prefix] == 0.0
        f0_mask = batch["f0_mask"][:, :prefix].logical_or(non_voiced)
        f0_prefix_mean = f0_target[:, :prefix].sum(dim=-1) / ((~f0_mask).sum(dim=-1))

        # continuation values
        dur_cont_mean = produced_durations[:, prefix:].sum(dim=-1) / (
            (~batch["dur_mask"][:, prefix:]).sum(dim=-1)
        )

        non_voiced = produced_f0[:, prefix:] == 0.0
        f0_mask = non_voiced
        f0_cont_mean = produced_f0[:, prefix:].sum(dim=-1) / ((~f0_mask).sum(dim=-1))

        assert not f0_cont_mean.isnan().any()

        mean_dur_prefix.append(dur_prefix_mean.cpu())
        mean_dur_cont.append(dur_cont_mean.cpu())

        mean_f0_prefix.append(f0_prefix_mean.cpu())
        mean_f0_cont.append(f0_cont_mean.cpu())

        if args.debug and len(mean_dur_prefix) > 10:
            break

    mean_dur_prefix, mean_dur_cont = torch.cat(mean_dur_prefix), torch.cat(
        mean_dur_cont
    )
    mean_f0_prefix, mean_f0_cont = torch.cat(mean_f0_prefix), torch.cat(mean_f0_cont)

    return mean_dur_prefix, mean_dur_cont, mean_f0_prefix, mean_f0_cont


def main(rank, world_size, args):
    start = time.time()

    if world_size > 1:
        torch.distributed.init_process_group(
            backend="gloo", init_method="env://", world_size=world_size, rank=rank
        )
        torch.cuda.set_device(rank % torch.cuda.device_count())

    raw_args = args

    args = convert_namespace_to_omegaconf(args)
    if args.common.seed is not None:
        np.random.seed(args.common.seed)
        utils.set_torch_seed(args.common.seed)

    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [raw_args.path], arg_overrides={"data": args.task.data}
    )

    tgt_dict = task.target_dictionary

    for model in models:
        model.prepare_for_inference_(args)
        model.cuda().eval()
        if raw_args.fp16:
            model = model.half()
    model = models[0]
516 |
+
|
517 |
+
config = ExpressiveCodeDataConfig(args.task.data)
|
518 |
+
|
519 |
+
dataset = CodeDataset(
|
520 |
+
manifest=config.manifests[raw_args.eval_subset],
|
521 |
+
dictionary=task.source_dictionary,
|
522 |
+
dur_dictionary=task.source_duration_dictionary,
|
523 |
+
f0_dictionary=task.source_f0_dictionary,
|
524 |
+
config=config,
|
525 |
+
discrete_dur=task.cfg.discrete_duration,
|
526 |
+
discrete_f0=task.cfg.discrete_f0,
|
527 |
+
log_f0=task.cfg.log_f0,
|
528 |
+
normalize_f0_mean=task.cfg.normalize_f0_mean,
|
529 |
+
normalize_f0_std=task.cfg.normalize_f0_std,
|
530 |
+
interpolate_f0=task.cfg.interpolate_f0,
|
531 |
+
shifts=task.cfg.stream_shifts,
|
532 |
+
return_filename=True,
|
533 |
+
strip_filename=False,
|
534 |
+
return_continuous_f0=raw_args.dequantize_prosody,
|
535 |
+
)
|
536 |
+
|
537 |
+
if raw_args.filter_names:
|
538 |
+
dataset = FilterNamesDataset(dataset, raw_args.filter_names)
|
539 |
+
|
540 |
+
criterion = task.build_criterion(model_args.criterion)
|
541 |
+
|
542 |
+
name2metric = {
|
543 |
+
"continuation": continuation,
|
544 |
+
"teacher_force_everything": teacher_force_everything,
|
545 |
+
"correlation": correlation,
|
546 |
+
}
|
547 |
+
|
548 |
+
name2keys = {
|
549 |
+
"continuation": (
|
550 |
+
"Token BLEU3",
|
551 |
+
"Duration NLL",
|
552 |
+
"Duration MAE",
|
553 |
+
"F0 NLL",
|
554 |
+
"F0 MAE",
|
555 |
+
"F0 sum",
|
556 |
+
"F0 sum_sq",
|
557 |
+
"Dur sum",
|
558 |
+
"Dur sum_sq",
|
559 |
+
),
|
560 |
+
"teacher_force_everything": ("token_loss", "duration_loss", "f0_loss"),
|
561 |
+
"correlation": ("Duration corr", "F0 corr"),
|
562 |
+
}
|
563 |
+
metric_name = raw_args.metric
|
564 |
+
|
565 |
+
metric = name2metric[metric_name]
|
566 |
+
results = metric(raw_args, dataset, model, criterion, tgt_dict, rank, world_size)
|
567 |
+
|
568 |
+
values = None
|
569 |
+
|
570 |
+
if metric_name not in [
|
571 |
+
"correlation",
|
572 |
+
]:
|
573 |
+
values, normalizers = results
|
574 |
+
values = maybe_aggregate_normalize(values, normalizers, world_size)
|
575 |
+
elif metric_name == "correlation":
|
576 |
+
values = maybe_aggregate_correlations(results, world_size)
|
577 |
+
else:
|
578 |
+
assert False
|
579 |
+
|
580 |
+
assert values is not None
|
581 |
+
summary = dict(zip(name2keys[raw_args.metric], values.tolist()))
|
582 |
+
if metric_name == "continuation":
|
583 |
+
summary["F0 Std"] = np.sqrt(-summary["F0 sum"] ** 2 + summary["F0 sum_sq"])
|
584 |
+
summary["Dur Std"] = np.sqrt(-summary["Dur sum"] ** 2 + summary["Dur sum_sq"])
|
585 |
+
del summary["F0 sum"]
|
586 |
+
del summary["F0 sum_sq"]
|
587 |
+
del summary["Dur sum"]
|
588 |
+
del summary["Dur sum_sq"]
|
589 |
+
|
590 |
+
summary["metric"] = metric_name
|
591 |
+
|
592 |
+
if rank == 0:
|
593 |
+
print(summary)
|
594 |
+
if raw_args.wandb:
|
595 |
+
wandb_results(summary, raw_args)
|
596 |
+
print("# finished in ", time.time() - start, "seconds")
|
597 |
+
|
598 |
+
|
599 |
+
def wandb_results(summary, raw_args):
|
600 |
+
import wandb
|
601 |
+
|
602 |
+
run = wandb.init(
|
603 |
+
project=raw_args.wandb_project_name, tags=raw_args.wandb_tags.split(",")
|
604 |
+
)
|
605 |
+
run.config.metric = raw_args.metric
|
606 |
+
run.config.model = raw_args.path
|
607 |
+
run.config.data = raw_args.data
|
608 |
+
|
609 |
+
if raw_args.wandb_run_name:
|
610 |
+
run.name = raw_args.wandb_run_name
|
611 |
+
run.save()
|
612 |
+
|
613 |
+
wandb.log(summary)
|
614 |
+
wandb.finish()
|
615 |
+
|
616 |
+
|
617 |
+
def maybe_aggregate_normalize(values, normalizers, world_size):
|
618 |
+
if world_size > 1:
|
619 |
+
torch.distributed.barrier()
|
620 |
+
|
621 |
+
torch.distributed.all_reduce_multigpu([values])
|
622 |
+
torch.distributed.all_reduce_multigpu([normalizers])
|
623 |
+
|
624 |
+
return values / normalizers
|
625 |
+
|
626 |
+
|
627 |
+
def maybe_aggregate_correlations(results, world_size):
|
628 |
+
if world_size > 1:
|
629 |
+
output = [None for _ in range(world_size)]
|
630 |
+
torch.distributed.all_gather_object(output, results)
|
631 |
+
mean_dur_prefix, mean_dur_cont, mean_f0_prefix, mean_f0_cont = [
|
632 |
+
torch.cat([x[i] for x in output]) for i in range(4)
|
633 |
+
]
|
634 |
+
else:
|
635 |
+
mean_dur_prefix, mean_dur_cont, mean_f0_prefix, mean_f0_cont = results
|
636 |
+
|
637 |
+
corr_dur = scipy.stats.pearsonr(mean_dur_prefix.numpy(), mean_dur_cont.numpy())[0]
|
638 |
+
corr_f0 = scipy.stats.pearsonr(mean_f0_prefix.numpy(), mean_f0_cont.numpy())[0]
|
639 |
+
values = torch.tensor([corr_dur, corr_f0])
|
640 |
+
|
641 |
+
return values
|
642 |
+
|
643 |
+
|
644 |
+
def cli_main():
|
645 |
+
parser = options.get_interactive_generation_parser()
|
646 |
+
parser.add_argument(
|
647 |
+
"--prefix-length",
|
648 |
+
type=int,
|
649 |
+
default=1,
|
650 |
+
help="Prompt prefix length (including <s>)",
|
651 |
+
)
|
652 |
+
parser.add_argument(
|
653 |
+
"--duration-scale",
|
654 |
+
type=float,
|
655 |
+
default=1,
|
656 |
+
help="Multiply durations by the given scaler",
|
657 |
+
)
|
658 |
+
parser.add_argument(
|
659 |
+
"--debug", action="store_true", help="Process only the first batch"
|
660 |
+
)
|
661 |
+
parser.add_argument("--n_hypotheses", type=int, default=1)
|
662 |
+
parser.add_argument("--filter-names", type=str, default=None)
|
663 |
+
parser.add_argument(
|
664 |
+
"--max-length", type=int, default=200, help="Maximal produced length"
|
665 |
+
)
|
666 |
+
|
667 |
+
parser.add_argument("--teacher-force-tokens", action="store_true", default=False)
|
668 |
+
parser.add_argument("--teacher-force-duration", action="store_true", default=False)
|
669 |
+
parser.add_argument("--teacher-force-f0", action="store_true", default=False)
|
670 |
+
|
671 |
+
parser.add_argument("--copy-target", action="store_true", default=False)
|
672 |
+
parser.add_argument("--min-length", type=int, default=None)
|
673 |
+
parser.add_argument("--f0-discretization-bounds", type=str, default=None)
|
674 |
+
parser.add_argument("--dequantize-prosody", action="store_true")
|
675 |
+
parser.add_argument("--batch-explosion-rate", type=int, default=1)
|
676 |
+
|
677 |
+
parser.add_argument(
|
678 |
+
"--metric",
|
679 |
+
choices=["continuation", "teacher_force_everything", "correlation"],
|
680 |
+
required=True,
|
681 |
+
)
|
682 |
+
|
683 |
+
parser.add_argument("--wandb", action="store_true")
|
684 |
+
parser.add_argument("--wandb-project-name", type=str, default="eslm")
|
685 |
+
parser.add_argument("--wandb-tags", type=str, default="")
|
686 |
+
parser.add_argument("--wandb-run-name", type=str, default="")
|
687 |
+
|
688 |
+
parser.add_argument("--T-token", type=float, default=1.0)
|
689 |
+
parser.add_argument("--T-duration", type=float, default=1.0)
|
690 |
+
parser.add_argument("--T-f0", type=float, default=1.0)
|
691 |
+
|
692 |
+
parser.add_argument("--n-workers", type=int, default=1)
|
693 |
+
|
694 |
+
parser.add_argument(
|
695 |
+
"--eval-subset", type=str, default="valid", choices=["valid", "test"]
|
696 |
+
)
|
697 |
+
|
698 |
+
args = options.parse_args_and_arch(parser)
|
699 |
+
|
700 |
+
assert (
|
701 |
+
args.prefix_length >= 1
|
702 |
+
), "Prefix length includes bos token <s>, hence the minimum is 1."
|
703 |
+
assert args.temperature >= 0.0, "T must be non-negative!"
|
704 |
+
|
705 |
+
if args.dequantize_prosody:
|
706 |
+
assert args.f0_discretization_bounds
|
707 |
+
|
708 |
+
world_size = args.n_workers or torch.cuda.device_count()
|
709 |
+
if world_size > 1:
|
710 |
+
import random
|
711 |
+
|
712 |
+
mp.set_start_method("spawn", force=True)
|
713 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
714 |
+
os.environ["MASTER_PORT"] = str(random.randint(10_000, 50_000))
|
715 |
+
|
716 |
+
mp.spawn(
|
717 |
+
main,
|
718 |
+
nprocs=world_size,
|
719 |
+
args=(
|
720 |
+
world_size,
|
721 |
+
args,
|
722 |
+
),
|
723 |
+
join=True,
|
724 |
+
)
|
725 |
+
else:
|
726 |
+
main(rank=0, world_size=world_size, args=args)
|
727 |
+
|
728 |
+
|
729 |
+
if __name__ == "__main__":
|
730 |
+
cli_main()
|
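A note on the "continuation" summary computed above: the "F0 sum" and "F0 sum_sq" entries are divided by the token count in maybe_aggregate_normalize, so by the time the Std is taken they are per-token first and second moments. A minimal standalone sketch of that step (the numbers are made up, not from the original file):

# Sketch only: mirrors summary["F0 Std"] above, assuming "F0 sum" and
# "F0 sum_sq" have already been normalized into per-token means.
import numpy as np

f0_mean = 120.0       # hypothetical normalized "F0 sum"
f0_mean_sq = 14500.0  # hypothetical normalized "F0 sum_sq"
f0_std = np.sqrt(-f0_mean ** 2 + f0_mean_sq)  # Var[x] = E[x^2] - E[x]^2
print(f0_std)  # 10.0 for these made-up numbers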
fairseq/examples/textless_nlp/pgslm/generate_waveform.py
ADDED
@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ast
import argparse
import json
import logging
from pathlib import Path
import soundfile as sf
import torch

from tqdm import tqdm

from fairseq import utils
from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder


logging.basicConfig()
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def dump_result(args, data, sample_id, pred_wav):
    assert "audio" in data or args.results_path is not None
    if args.results_path:
        fname = Path(data["audio"]).name if "audio" in data else f"{sample_id}_pred.wav"
        out_file = Path(args.results_path) / fname

        sf.write(
            out_file.as_posix(),
            pred_wav.detach().cpu().numpy(),
            args.sample_rate,
        )


def load_data(in_file):
    with open(in_file) as f:
        data = [ast.literal_eval(line.strip()) for line in f]

    return data


def get_f0_upsample_ratio(code_hop_size, f_hop_size):
    ratio = (code_hop_size // 160) // (f_hop_size // 256) * 2
    return ratio


def main(args):
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    with open(args.vocoder_cfg) as f:
        vocoder_cfg = json.load(f)
    vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg)
    if use_cuda:
        vocoder = vocoder.cuda()

    data = load_data(args.in_file)

    if args.results_path:
        Path(args.results_path).mkdir(exist_ok=True, parents=True)

    for i, d in tqdm(enumerate(data), total=len(data)):
        code_key = "cpc_km100" if "cpc_km100" in d else "hubert"
        code = list(map(int, d[code_key].split()))

        x = {
            "code": torch.LongTensor(code).view(1, -1),
            "f0": torch.Tensor(d["f0"]).view(1, -1),
        }

        f0_up_ratio = get_f0_upsample_ratio(
            vocoder_cfg["code_hop_size"], vocoder_cfg["hop_size"]
        )
        if f0_up_ratio > 1:
            bsz, cond_length = x["f0"].size()
            x["f0"] = x["f0"].unsqueeze(2).repeat(1, 1, f0_up_ratio).view(bsz, -1)

        x = utils.move_to_cuda(x) if use_cuda else x
        wav = vocoder(x)
        dump_result(args, d, i, wav)


def cli_main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--in-file",
        type=str,
        required=True,
        help="Input file following the same format of the output from sample.py ('f0' and 'cpc_km100/hubert' are required fields)",
    )
    parser.add_argument(
        "--vocoder", type=str, required=True, help="path to the vocoder"
    )
    parser.add_argument(
        "--vocoder-cfg",
        type=str,
        required=True,
        help="path to the vocoder config",
    )
    parser.add_argument("--sample-rate", type=int, default=16_000)
    parser.add_argument(
        "--results-path",
        type=str,
        default=None,
        help="Output directory. If not set, the audios will be stored following the 'audio' field specified in the input file.",
    )
    parser.add_argument("--cpu", action="store_true", help="run on CPU")

    args = parser.parse_args()

    main(args)


if __name__ == "__main__":
    cli_main()
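For reference, a hypothetical way to drive this script programmatically rather than through cli_main(): the snippet below just fills the same argparse namespace and calls main() directly. All paths are placeholders, not values from the original diff.

# Sketch only: programmatic equivalent of the CLI, with placeholder paths.
from argparse import Namespace

args = Namespace(
    in_file="samples.txt",           # placeholder: output of sample.py
    vocoder="vocoder.pt",            # placeholder: CodeHiFiGAN checkpoint
    vocoder_cfg="vocoder_cfg.json",  # placeholder: vocoder config with code_hop_size / hop_size
    sample_rate=16_000,
    results_path="generated_wavs",   # placeholder output directory
    cpu=False,
)
main(args)  # assumes generate_waveform.py has been imported or extended in place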
fairseq/examples/textless_nlp/pgslm/inference_dataset.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch


class InferenceDataset:
    def __init__(
        self,
        dataset,
        prefix,
        only_prefix=True,
        presort_by_length=True,
        filter_short=False,
        min_length=None,
    ):
        self.dataset = dataset
        self.collater = self.dataset.collater
        self.prefix = prefix
        self.only_prefix = only_prefix
        self.filter_short = filter_short

        self.remapping = list(range(len(self.dataset)))
        if min_length:
            assert min_length >= prefix + 1

        length_thr = prefix + 1 if not min_length else min_length

        if filter_short:
            self.remapping = list(
                filter(
                    lambda i: self.dataset[i]["dur_source"].sum() > length_thr,
                    self.remapping,
                )
            )
            print(
                f"# the initial dataset of {len(self.dataset)} examples became {len(self.remapping)} after filtering"
                f" examples shorter than {length_thr} (in duration units)"
            )

        if presort_by_length:
            lengths = {index: dataset.size(index) for index in self.remapping}
            self.remapping.sort(key=lambda i: lengths[i])

    @property
    def pads(self):
        return self.dataset.pads

    def __len__(self):
        return len(self.remapping)

    def original_size(self, k):
        k = self.remapping[k]
        return self.dataset.size(k)

    def __getitem__(self, k):
        k = self.remapping[k]
        channels = self.dataset[k]

        if self.prefix and self.only_prefix:
            dur_channel = channels["dur_source"]
            assert dur_channel.sum() >= self.prefix

            token_times = dur_channel.cumsum(dim=-1)
            cut_after = torch.searchsorted(token_times, torch.tensor(self.prefix))

            r = {}
            for channel_name, value in channels.items():
                if isinstance(value, torch.Tensor) and "source" in channel_name:
                    # if self.filter_short: assert value.size(0) >= self.prefix
                    r[channel_name] = value[: cut_after + 1]
                else:
                    r[channel_name] = value

            r["prefix"] = cut_after + 1
        else:
            r = channels

        return r


def explode_batch(batch, times):
    if times == 1:
        return batch

    new_batch = {}

    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            assert value.size(0) == 1
            new_batch[key] = torch.cat([value] * times)
        elif key in ["ntokens", "nsentences"]:
            new_batch[key] = value * times
        elif key in ["prefix", "filename"]:
            new_batch[key] = value
        elif key == "net_input":
            new_batch[key] = explode_batch(value, times)
        else:
            assert False, key
    return new_batch
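A small sketch of what explode_batch does (the batch contents below are invented for illustration): it repeats a single-example batch along the batch dimension so several continuations can be sampled for the same prefix, while counters are scaled and per-example metadata is kept as-is.

# Sketch only: illustrates explode_batch with a made-up single-example batch.
import torch

batch = {
    "net_input": {"src_tokens": torch.tensor([[5, 7, 9]])},  # hypothetical tokens
    "ntokens": 3,
    "prefix": [2],
}
exploded = explode_batch(batch, times=4)
assert exploded["net_input"]["src_tokens"].size(0) == 4  # tensor repeated 4x
assert exploded["ntokens"] == 12                          # counter scaled
assert exploded["prefix"] == [2]                          # metadata untouched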
fairseq/examples/textless_nlp/pgslm/naive_decoder.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import warnings


class Naive_F0_Decoder(torch.nn.Module):
    def __init__(self, bounds_path, n_units=32):
        super().__init__()

        bounds = torch.load(bounds_path)
        bounds = torch.from_numpy(bounds[n_units])
        assert bounds.ndim == 1

        pad = torch.tensor([-5.0, -5.0])  # bos, eos, pad are in the dictionary
        centers = torch.cat(
            [bounds[0:1], 0.5 * (bounds[1:] + bounds[:-1]), bounds[-1:], pad[:]]
        )

        self.embedding = torch.nn.Embedding.from_pretrained(
            centers.unsqueeze(-1), freeze=True
        )
        self.max_n = self.embedding.weight.numel()

    def forward(self, discrete_f0: torch.Tensor):
        in_bounds = (0 <= discrete_f0).all() and (discrete_f0 < self.max_n).all()
        if not in_bounds:
            warnings.warn(
                f"F0 contains some weird outputs: discrete_f0.max().item()={discrete_f0.max().item()} discrete_f0.min().item()={discrete_f0.min().item()}; "
                f"while we have embeddings for {self.max_n} values. "
                "Assuming this is a no-prosody model -- but be careful!"
            )

            mask = discrete_f0 >= self.max_n
            discrete_f0 = discrete_f0.masked_fill(mask, self.max_n - 1)

        return self.embedding(discrete_f0).squeeze(-1)
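A hedged usage sketch: Naive_F0_Decoder turns discrete F0 bin indices back into approximate continuous values via a frozen embedding of bin centers. The bounds file written below is a synthetic stand-in for the real f0_discretization_bounds checkpoint (in practice it is produced by quantize_f0.py, see further down).

# Sketch only: exercising Naive_F0_Decoder with synthetic bin boundaries.
import numpy as np
import torch

torch.save({32: np.linspace(-2.0, 2.0, 32, dtype=np.float32)}, "fake_f0_bins.pt")
decoder = Naive_F0_Decoder("fake_f0_bins.pt", n_units=32)
codes = torch.tensor([[0, 15, 31]])  # hypothetical discrete F0 indices
print(decoder(codes))  # roughly the centers of the selected bins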
fairseq/examples/textless_nlp/pgslm/prepare_dataset.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from multiprocessing import Pool

import os
from collections import defaultdict
from itertools import starmap

import torch
from npy_append_array import NpyAppendArray
from tqdm import tqdm

from data_utils import dump_speaker_f0_stat, F0Stat, load_f0
from fairseq.data.codedataset import (
    ExpressiveCodeDataConfig,
    parse_manifest,
    F0_FRAME_SPACE,
    align_f0_to_durations,
)
from fairseq.tasks.speech_ulm_task import UnitDictionary


def load_meta(meta_path, split):
    config = ExpressiveCodeDataConfig(meta_path)
    manifest_path = config.manifests[split]
    dictionary = UnitDictionary(n_units=config.n_units)
    audio_paths, codes, durs, speakers = parse_manifest(manifest_path, dictionary)
    return config, audio_paths, codes, durs, speakers


def _align_f0(f0, dur, ratio, frm_tol=5):
    if f0 is None:
        seg_f0 = torch.zeros_like(dur, dtype=torch.float)
    else:
        seg_f0 = align_f0_to_durations(f0, dur, ratio, tol=frm_tol * ratio)
    return seg_f0.numpy()  # try a hacky stuff


def align_f0(path_to_f0, audio_paths, durs, ratio, mp=False):
    chunk_size = 2000
    num_procs = 40
    iterable = ((path_to_f0[p], d, ratio) for p, d in zip(audio_paths, durs))

    seg_f0s = []
    if mp:
        with Pool(num_procs) as pool:
            iterator = tqdm(
                pool.istarmap(_align_f0, iterable, chunk_size),
                desc="align f0",
                total=len(durs),
            )
            for seg_f0 in iterator:
                seg_f0s.append(torch.from_numpy(seg_f0).float())
    else:
        iterator = tqdm(starmap(_align_f0, iterable), desc="align f0", total=len(durs))
        for seg_f0 in iterator:
            seg_f0s.append(torch.from_numpy(seg_f0).float())

    return seg_f0s


def prepare_seg_data(config, audio_paths, codes, durs, speakers, path_to_f0):
    ratio = config.code_hop_size / (config.sampling_rate * F0_FRAME_SPACE)
    seg_f0s = align_f0(path_to_f0, audio_paths, durs, ratio)
    data = {
        "codes": codes,
        "duration": durs,
        "f0": seg_f0s,
        "speaker": speakers,
        "path": audio_paths,
    }
    return data


def dump_seg_data(data, out_prefix):
    key_targs = {
        "codes": f"{out_prefix}.code.npy",
        "duration": f"{out_prefix}.dur.npy",
        "f0": f"{out_prefix}.f0.npy",
    }
    for key, targ in key_targs.items():
        assert not os.path.exists(targ)
        npaa = NpyAppendArray(targ)
        for utt_data in tqdm(data[key], desc=f"dumping {key}"):
            npaa.append(utt_data.numpy())

    assert not os.path.exists(f"{out_prefix}.path.txt")
    with open(f"{out_prefix}.path.txt", "w") as f:
        for x in data["path"]:
            f.write(f"{str(x)}\n")

    assert not os.path.exists(f"{out_prefix}.leng.txt")
    with open(f"{out_prefix}.leng.txt", "w") as f:
        for x in data["codes"]:
            f.write(f"{len(x)}\n")

    assert not os.path.exists(f"{out_prefix}.speaker.txt")
    with open(f"{out_prefix}.speaker.txt", "w") as f:
        for x in data["speaker"]:
            f.write(f"{str(x)}\n")

    print(f"wrote to files with prefix {out_prefix}")


def main(meta_path, f0_dir, splits, nshards_list):
    speaker_to_stat = defaultdict(F0Stat)
    if len(nshards_list) == 1:
        nshards_list = nshards_list * len(splits)
    else:
        assert len(nshards_list) == len(splits)

    for split, nshards in zip(splits, nshards_list):
        config, audio_paths, codes, durs, speakers = load_meta(meta_path, split)
        path_to_f0 = load_f0(f"{f0_dir}/{split}", nshards)

        # segment-level data
        data = prepare_seg_data(config, audio_paths, codes, durs, speakers, path_to_f0)
        dump_seg_data(data, config.manifests[split])

        # speaker f0
        for audio_path, speaker in tqdm(zip(audio_paths, speakers)):
            f0 = path_to_f0[audio_path]
            speaker_to_stat[speaker].update(f0)
        dump_speaker_f0_stat(speaker_to_stat, config.manifests[split])


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("meta_path")
    parser.add_argument("f0_dir", help="out_dir from preprocess_f0")
    parser.add_argument("--splits", nargs="+", default=["train", "valid"])
    parser.add_argument(
        "--nshards_list", type=int, nargs="+", default=[20], help="number of f0 shards"
    )
    args = parser.parse_args()
    print(args)

    main(**vars(args))
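A hypothetical call pattern for this script, equivalent to its command line but written as a direct call to main(); both paths are placeholders, not values from the original diff, and the f0_dir is expected to hold the per-split shards written by preprocess_f0.py.

# Sketch only: programmatic equivalent of running prepare_dataset.py.
main(
    meta_path="config.json",  # placeholder ExpressiveCodeDataConfig json
    f0_dir="f0_out",          # placeholder: out_dir produced by preprocess_f0.py
    splits=["train", "valid"],
    nshards_list=[20],
)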
fairseq/examples/textless_nlp/pgslm/preprocess_f0.py
ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
from tqdm import tqdm
from data_utils import load_audio_path
from fairseq.data.codedataset import get_f0_by_filename


def process_one(path, sr):
    """
    Args:
        path: audio file path
        sr: sampling rate
    """
    try:
        # YAAPT throws errors in some rare cases
        f0 = get_f0_by_filename(path, sr)
    except Exception as e:
        print(
            f"WARNING: error when processing {path}. set f0 to zero. original error message:\n{e}"
        )
        f0 = None
    return f0


def main(file_path, out_dir, nshards, rank, sampling_rate):
    # load data
    audio_paths = load_audio_path(file_path)

    # shard
    assert nshards <= len(audio_paths) and nshards > 0
    shard_size = len(audio_paths) / nshards
    s = int(round((rank - 1) * shard_size))
    e = int(round(rank * shard_size))
    audio_paths = audio_paths[s:e]

    # process
    path_to_f0 = {}
    for i, audio_path in enumerate(tqdm(audio_paths)):
        f0 = process_one(audio_path, sampling_rate)
        path_to_f0[audio_path] = f0
    print(f"finished processing {len(path_to_f0)} utterances ({s}-{e})")

    f0_path = f"{out_dir}/f0_{rank}_{nshards}.pt"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(path_to_f0, f0_path)
    print(f"saved to {f0_path}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("file_path")
    parser.add_argument("out_dir")
    parser.add_argument("--nshards", type=int, default=20)
    parser.add_argument("--rank", type=int, default=1)
    parser.add_argument("--sampling_rate", type=int, default=16000)
    args = parser.parse_args()

    main(**vars(args))
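Each invocation of this script processes one shard (rank is 1-based), so F0 extraction can be parallelized across jobs. A hedged sketch of a sequential driver that covers all shards; the manifest and output paths are placeholders.

# Sketch only: looping over all shards in-process, mirroring the CLI's rank/nshards arguments.
nshards = 20
for rank in range(1, nshards + 1):
    main(
        file_path="train_manifest.txt",  # placeholder manifest
        out_dir="f0_out/train",          # placeholder output directory (one per split)
        nshards=nshards,
        rank=rank,
        sampling_rate=16000,
    )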
fairseq/examples/textless_nlp/pgslm/quantize_f0.py
ADDED
@@ -0,0 +1,94 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from functools import partial

import numpy as np
import torch
from tqdm import tqdm

from data_utils import dump_speaker_f0_stat, F0Stat, load_audio_path, load_f0


def load_speaker(path):
    speakers = []
    with open(path) as f:
        for line in f.readlines():
            sample = eval(line.strip())
            assert "speaker" in sample
            speakers.append(sample["speaker"])
    return speakers


def quantize_f0(speaker_to_f0, f0_stats, nbins, normalize, log):
    f0_all = []
    for speaker, f0 in speaker_to_f0.items():
        f0 = f0.raw_data
        if log:
            f0 = f0.log()
        mean = f0_stats[speaker]["logf0_mean"] if log else f0_stats[speaker]["f0_mean"]
        std = f0_stats[speaker]["logf0_std"] if log else f0_stats[speaker]["f0_std"]
        if normalize == "mean":
            f0 = f0 - mean
        elif normalize == "meanstd":
            f0 = (f0 - mean) / std
        f0_all.extend(f0.tolist())

    hist, bin_x = np.histogram(f0_all, 100000)
    cum_hist = np.cumsum(hist) / len(f0_all) * 100

    f0_bin = {}
    for num_bin in nbins:
        bin_offset = []
        bin_size = 100 / num_bin
        threshold = bin_size
        for i in range(num_bin - 1):
            index = (np.abs(cum_hist - threshold)).argmin()
            bin_offset.append(bin_x[index])
            threshold += bin_size
        f0_bin[num_bin] = np.array(bin_offset)

    return f0_bin


def main(file_path, f0_dir, out_dir, out_prefix, nbins, nshards, normalize, log):
    audio_paths = load_audio_path(file_path)
    path_to_f0 = load_f0(f0_dir, nshards)

    speakers = load_speaker(file_path)
    speaker_to_f0 = defaultdict(partial(F0Stat, True))

    # speaker f0 stats
    for audio_path, speaker in tqdm(zip(audio_paths, speakers)):
        f0 = path_to_f0[audio_path]
        speaker_to_f0[speaker].update(f0)
    f0_stats = dump_speaker_f0_stat(speaker_to_f0, f"{out_dir}/{out_prefix}")

    # quantize
    f0_bin = quantize_f0(speaker_to_f0, f0_stats, nbins, normalize, log)
    log_suffix = "_log" if log else ""
    f0_bin_out_file = f"{out_dir}/{out_prefix}_{normalize}_norm{log_suffix}_f0_bin.th"
    torch.save(f0_bin, f0_bin_out_file)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("file_path")
    parser.add_argument("f0_dir", help="out_dir from preprocess_f0")
    parser.add_argument("out_dir")
    parser.add_argument("out_prefix")
    parser.add_argument("--nbins", nargs="+", type=int, default=[32])
    parser.add_argument("--nshards", type=int, default=20, help="number of f0 shards")
    parser.add_argument(
        "--normalize", type=str, choices=["meanstd", "mean", "none"], default="mean"
    )
    parser.add_argument("--log", action="store_true")
    args = parser.parse_args()
    print(args)

    main(**vars(args))
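The quantization above places bin boundaries at (approximately) equal percentiles of the pooled, optionally normalized and log-scaled, F0 distribution, so each bin ends up roughly equally populated. A minimal standalone sketch of the same idea on synthetic data (everything below is illustrative, not from the original file):

# Sketch only: equal-occupancy binning of a 1-D sample, mirroring quantize_f0.
import numpy as np

f0_all = np.random.RandomState(0).randn(10_000)  # stand-in for pooled F0 values
num_bin = 32
hist, bin_x = np.histogram(f0_all, 100000)
cum_hist = np.cumsum(hist) / len(f0_all) * 100    # cumulative percentage

bin_offset, threshold = [], 100 / num_bin
for _ in range(num_bin - 1):
    index = np.abs(cum_hist - threshold).argmin() # boundary closest to the target percentile
    bin_offset.append(bin_x[index])
    threshold += 100 / num_bin
print(np.array(bin_offset))  # 31 boundaries -> 32 roughly equally populated bins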