diff --git a/fairseq/examples/hubert/tests/sample.base.L9.npy b/fairseq/examples/hubert/tests/sample.base.L9.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7a2b109228036a1de485a3bb942e1df869999c55
--- /dev/null
+++ b/fairseq/examples/hubert/tests/sample.base.L9.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b44dc3a0519f6adb670a5da7410c1405f4fdb0e4866e5fc4983b136480212f34
+size 1831104
diff --git a/fairseq/examples/hubert/tests/sample.large.L20.npy b/fairseq/examples/hubert/tests/sample.large.L20.npy
new file mode 100644
index 0000000000000000000000000000000000000000..f5a7dcf1a9ccaa6d8e78a75e3ed749a05775a4e5
--- /dev/null
+++ b/fairseq/examples/hubert/tests/sample.large.L20.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c82839134cc2340355eb49b41e7e3517e8e9dfdfaa6fc28f0464cc8ae9569ee
+size 2441408
diff --git a/fairseq/examples/simultaneous_translation/__pycache__/__init__.cpython-310.pyc b/fairseq/examples/simultaneous_translation/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..859b7aa20ee471b66aec49c3c6c29986a2602404
Binary files /dev/null and b/fairseq/examples/simultaneous_translation/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/examples/simultaneous_translation/utils/__pycache__/__init__.cpython-310.pyc b/fairseq/examples/simultaneous_translation/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09b2407a5f228a28677eb600ce91c4966824ba81
Binary files /dev/null and b/fairseq/examples/simultaneous_translation/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/examples/simultaneous_translation/utils/__pycache__/functions.cpython-310.pyc b/fairseq/examples/simultaneous_translation/utils/__pycache__/functions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e6f44ce06cc738ef831f10b89450670fce7df67
Binary files /dev/null and b/fairseq/examples/simultaneous_translation/utils/__pycache__/functions.cpython-310.pyc differ
diff --git a/fairseq/examples/simultaneous_translation/utils/__pycache__/p_choose_strategy.cpython-310.pyc b/fairseq/examples/simultaneous_translation/utils/__pycache__/p_choose_strategy.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9182280a2ba539fa18888b4d0b0d606e4b9f9266
Binary files /dev/null and b/fairseq/examples/simultaneous_translation/utils/__pycache__/p_choose_strategy.cpython-310.pyc differ
diff --git a/fairseq/examples/speech_recognition/README.md b/fairseq/examples/speech_recognition/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5f9b27880e439dc1803db7f6f0765edaefeb7097
--- /dev/null
+++ b/fairseq/examples/speech_recognition/README.md
@@ -0,0 +1,87 @@
+### 2021 Update: We are merging this example into the [S2T framework](../speech_to_text), which supports more generic speech-to-text tasks (e.g. speech translation) and more flexible data processing pipelines. Please stay tuned.
+
+# Speech Recognition
+`examples/speech_recognition` implements the ASR task in fairseq, along with the features, datasets, models and loss functions needed to train and evaluate the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
+
+
+## Additional dependencies
+On top of the main fairseq dependencies, there are a couple of additional requirements.
+
+1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
+2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from the sctk package [here](http://www.openslr.org/4/). Training and inference do not require the Sclite dependency.
+3) [sentencepiece](https://github.com/google/sentencepiece) is required to create datasets with word-piece targets.
+
+## Preparing librispeech data
+```
+./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
+```
+
+## Training librispeech data
+```
+python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
+```
+
+## Inference for librispeech
+`$SET` can be `test_clean` or `test_other`.
+Any checkpoint in `$MODEL_PATH` can be selected. In this example we use `checkpoint_last.pt`.
+```
+python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
+```
+
+## Scoring with Sclite
+```
+sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
+```
+The `Sum/Avg` row in the first table of the report contains the WER.
+
+## Using flashlight (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) components
+[flashlight](https://github.com/facebookresearch/flashlight) now has integration with fairseq. Currently this includes:
+
+* AutoSegmentationCriterion (ASG)
+* flashlight-style Conv/GLU model
+* flashlight's beam search decoder
+
+To use these, follow the instructions on [this page](https://github.com/flashlight/flashlight/tree/e16682fa32df30cbf675c8fe010f929c61e3b833/bindings/python) to install the python bindings. **Flashlight v0.3.2** must be used to install the bindings. Running:
+```
+git clone --branch v0.3.2 https://github.com/flashlight/flashlight
+```
+will properly clone and check out this version.
+
+## Training librispeech data (flashlight style, Conv/GLU + ASG loss)
+Training command:
+```
+python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
+```
+
+Note that the ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
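+
+## Computing WER without Sclite (optional)
+If Sclite is not available, the WER reported above can also be approximated directly from the `ref.word-*` and `hypo.word-*` files that `infer.py` writes to `$RES_DIR`, using the `editdistance` package that `infer.py` itself imports. This is only a sketch, not part of the original recipe; it assumes `--nbest 1` and the `checkpoint_last.pt` / `test_clean` file naming from the commands above.
+```python
+import editdistance
+
+def read_words(path):
+    # each line looks like "<words> (<speaker>-<utt_id>)"; drop the trailing "(...)" tag
+    with open(path, encoding="utf-8") as f:
+        return [line.rsplit("(", 1)[0].split() for line in f]
+
+refs = read_words("ref.word-checkpoint_last.pt-test_clean.txt")
+hyps = read_words("hypo.word-checkpoint_last.pt-test_clean.txt")
+errors = sum(editdistance.eval(h, r) for h, r in zip(hyps, refs))
+wer = 100.0 * errors / sum(len(r) for r in refs)
+print(f"WER: {wer:.2f}")
+```
+Both files are written one utterance per line in the same order, so a line-by-line comparison is sufficient.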
+ +## Inference for librispeech (flashlight decoder, n-gram LM) +Inference command: +``` +python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition +``` + +`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a flashlight-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels): +``` +doorbell D O 1 R B E L 1 ▁ +``` +For CTC inference with word-pieces, repetition labels are not used and the lexicon should have most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this): +``` +doorbell ▁DOOR BE LL +doorbell ▁DOOR B E LL +doorbell ▁DO OR BE LL +doorbell ▁DOOR B EL L +doorbell ▁DOOR BE L L +doorbell ▁DO OR B E LL +doorbell ▁DOOR B E L L +doorbell ▁DO OR B EL L +doorbell ▁DO O R BE LL +doorbell ▁DO OR BE L L +``` +Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`). + +## Inference for librispeech (flashlight decoder, viterbi only) +Inference command: +``` +python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition +``` diff --git a/fairseq/examples/speech_recognition/__init__.py b/fairseq/examples/speech_recognition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0278f6a27340c7ff7e207d09348483d1b0d3a100 --- /dev/null +++ b/fairseq/examples/speech_recognition/__init__.py @@ -0,0 +1 @@ +from . import criterions, models, tasks # noqa diff --git a/fairseq/examples/speech_recognition/criterions/ASG_loss.py b/fairseq/examples/speech_recognition/criterions/ASG_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..41f50bbd70388ce723f2d316d4e9776bcd6be3c9 --- /dev/null +++ b/fairseq/examples/speech_recognition/criterions/ASG_loss.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
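+
+# ASG (AutoSegmentationCriterion) training criterion backed by flashlight's
+# python bindings (flashlight.lib.sequence.criterion.ASGLoss); see the
+# flashlight section of examples/speech_recognition/README.md for install notes.
+# Targets are packed with replabels (data/replabels.py) before being passed to
+# the loss, e.g. with --max-replabel 2 the letter sequence "D O O R B E L L"
+# becomes "D O 1 R B E L 1", matching the lexicon format shown in the README.
+# For the first --linseg-updates training updates, targets are spread linearly
+# over the T output frames (LinSeg initialization).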
+ +import torch +from examples.speech_recognition.data.replabels import pack_replabels +from fairseq import utils +from fairseq.criterions import FairseqCriterion, register_criterion + + +@register_criterion("asg_loss") +class ASGCriterion(FairseqCriterion): + @staticmethod + def add_args(parser): + group = parser.add_argument_group("ASG Loss") + group.add_argument( + "--asg-transitions-init", + help="initial diagonal value of transition matrix", + type=float, + default=0.0, + ) + group.add_argument( + "--max-replabel", help="maximum # of replabels", type=int, default=2 + ) + group.add_argument( + "--linseg-updates", + help="# of training updates to use LinSeg initialization", + type=int, + default=0, + ) + group.add_argument( + "--hide-linseg-messages", + help="hide messages about LinSeg initialization", + action="store_true", + ) + + def __init__( + self, + task, + silence_token, + asg_transitions_init, + max_replabel, + linseg_updates, + hide_linseg_messages, + ): + from flashlight.lib.sequence.criterion import ASGLoss, CriterionScaleMode + + super().__init__(task) + self.tgt_dict = task.target_dictionary + self.eos = self.tgt_dict.eos() + self.silence = ( + self.tgt_dict.index(silence_token) + if silence_token in self.tgt_dict + else None + ) + self.max_replabel = max_replabel + + num_labels = len(self.tgt_dict) + self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT) + self.asg.trans = torch.nn.Parameter( + asg_transitions_init * torch.eye(num_labels), requires_grad=True + ) + + self.linseg_progress = torch.nn.Parameter( + torch.tensor([0], dtype=torch.int), requires_grad=False + ) + self.linseg_maximum = linseg_updates + self.linseg_message_state = "none" if hide_linseg_messages else "start" + + @classmethod + def build_criterion(cls, args, task): + return cls( + task, + args.silence_token, + args.asg_transitions_init, + args.max_replabel, + args.linseg_updates, + args.hide_linseg_messages, + ) + + def linseg_step(self): + if not self.training: + return False + if self.linseg_progress.item() < self.linseg_maximum: + if self.linseg_message_state == "start": + print("| using LinSeg to initialize ASG") + self.linseg_message_state = "finish" + self.linseg_progress.add_(1) + return True + elif self.linseg_message_state == "finish": + print("| finished LinSeg initialization") + self.linseg_message_state = "none" + return False + + def replace_eos_with_silence(self, tgt): + if tgt[-1] != self.eos: + return tgt + elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence): + return tgt[:-1] + else: + return tgt[:-1] + [self.silence] + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + + net_output = model(**sample["net_input"]) + emissions = net_output["encoder_out"].transpose(0, 1).contiguous() + B = emissions.size(0) + T = emissions.size(1) + device = emissions.device + + target = torch.IntTensor(B, T) + target_size = torch.IntTensor(B) + using_linseg = self.linseg_step() + + for b in range(B): + initial_target_size = sample["target_lengths"][b].item() + if initial_target_size == 0: + raise ValueError("target size cannot be zero") + + tgt = sample["target"][b, :initial_target_size].tolist() + tgt = self.replace_eos_with_silence(tgt) + tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel) + tgt = tgt[:T] + + if using_linseg: + tgt = [tgt[t * len(tgt) // T] for t in range(T)] + + target[b][: len(tgt)] = torch.IntTensor(tgt) + target_size[b] = len(tgt) + + loss = self.asg.forward(emissions, target.to(device), target_size.to(device)) + + if reduce: + loss = torch.sum(loss) + + sample_size = ( + sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + agg_output = { + "loss": loss_sum / nsentences, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + return agg_output diff --git a/fairseq/examples/speech_recognition/criterions/__init__.py b/fairseq/examples/speech_recognition/criterions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..579abd2ace1b14b80f5e53e5c96583e4d5b14c52 --- /dev/null +++ b/fairseq/examples/speech_recognition/criterions/__init__.py @@ -0,0 +1,17 @@ +import importlib +import os + + +# ASG loss requires flashlight bindings +files_to_skip = set() +try: + import flashlight.lib.sequence.criterion +except ImportError: + files_to_skip.add("ASG_loss.py") + +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip: + criterion_name = file[: file.find(".py")] + importlib.import_module( + "examples.speech_recognition.criterions." + criterion_name + ) diff --git a/fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py b/fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4d8ba3802a2da9467c42b0aa18653c7bbb2ec9 --- /dev/null +++ b/fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py @@ -0,0 +1,130 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
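+
+# Cross-entropy criterion that additionally tracks token-level accuracy ("acc")
+# as well as token, sentence and frame counts in its logging output. This is
+# the --criterion cross_entropy_acc used by the vggtransformer recipe in
+# examples/speech_recognition/README.md.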
+ +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.criterions import FairseqCriterion, register_criterion + + +@register_criterion("cross_entropy_acc") +class CrossEntropyWithAccCriterion(FairseqCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + + def compute_loss(self, model, net_output, target, reduction, log_probs): + # N, T -> N * T + target = target.view(-1) + lprobs = model.get_normalized_probs(net_output, log_probs=log_probs) + if not hasattr(lprobs, "batch_first"): + logging.warning( + "ERROR: we need to know whether " + "batch first for the net output; " + "you need to set batch_first attribute for the return value of " + "model.get_normalized_probs. Now, we assume this is true, but " + "in the future, we will raise exception instead. " + ) + batch_first = getattr(lprobs, "batch_first", True) + if not batch_first: + lprobs = lprobs.transpose(0, 1) + + # N, T, D -> N * T, D + lprobs = lprobs.view(-1, lprobs.size(-1)) + loss = F.nll_loss( + lprobs, target, ignore_index=self.padding_idx, reduction=reduction + ) + return lprobs, loss + + def get_logging_output(self, sample, target, lprobs, loss): + target = target.view(-1) + mask = target != self.padding_idx + correct = torch.sum( + lprobs.argmax(1).masked_select(mask) == target.masked_select(mask) + ) + total = torch.sum(mask) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + + logging_output = { + "loss": utils.item(loss.data), # * sample['ntokens'], + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "correct": utils.item(correct.data), + "total": utils.item(total.data), + "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(), + } + + return sample_size, logging_output + + def forward(self, model, sample, reduction="sum", log_probs=True): + """Computes the cross entropy with accuracy metric for the given sample. + + This is similar to CrossEntropyCriterion in fairseq, but also + computes accuracy metrics as part of logging + + Args: + logprobs (Torch.tensor) of shape N, T, D i.e. + batchsize, timesteps, dimensions + targets (Torch.tensor) of shape N, T i.e batchsize, timesteps + + Returns: + tuple: With three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + + TODO: + * Currently this Criterion will only work with LSTMEncoderModels or + FairseqModels which have decoder, or Models which return TorchTensor + as net_output. + We need to make a change to support all FairseqEncoder models. 
+ """ + net_output = model(**sample["net_input"]) + target = model.get_targets(sample, net_output) + lprobs, loss = self.compute_loss( + model, net_output, target, reduction, log_probs + ) + sample_size, logging_output = self.get_logging_output( + sample, target, lprobs, loss + ) + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + correct_sum = sum(log.get("correct", 0) for log in logging_outputs) + total_sum = sum(log.get("total", 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + nframes = sum(log.get("nframes", 0) for log in logging_outputs) + agg_output = { + "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0, + # if args.sentence_avg, then sample_size is nsentences, then loss + # is per-sentence loss; else sample_size is ntokens, the loss + # becomes per-output token loss + "ntokens": ntokens, + "nsentences": nsentences, + "nframes": nframes, + "sample_size": sample_size, + "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0, + "correct": correct_sum, + "total": total_sum, + # total is the number of validate tokens + } + if sample_size != ntokens: + agg_output["nll_loss"] = loss_sum / ntokens / math.log(2) + # loss: per output token loss + # nll_loss: per sentence loss + return agg_output diff --git a/fairseq/examples/speech_recognition/data/__init__.py b/fairseq/examples/speech_recognition/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47bb6e24ddf25aa4fd5bf0fe9672f89099efb9b4 --- /dev/null +++ b/fairseq/examples/speech_recognition/data/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .asr_dataset import AsrDataset + + +__all__ = [ + "AsrDataset", +] diff --git a/fairseq/examples/speech_recognition/data/asr_dataset.py b/fairseq/examples/speech_recognition/data/asr_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..63a6fcac85d73b1fce8e4d044b4209b1b67fa8ce --- /dev/null +++ b/fairseq/examples/speech_recognition/data/asr_dataset.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import numpy as np +from fairseq.data import FairseqDataset + +from . import data_utils +from .collaters import Seq2SeqCollater + + +class AsrDataset(FairseqDataset): + """ + A dataset representing speech and corresponding transcription. + + Args: + aud_paths: (List[str]): A list of str with paths to audio files. + aud_durations_ms (List[int]): A list of int containing the durations of + audio files. + tgt (List[torch.LongTensor]): A list of LongTensors containing the indices + of target transcriptions. + tgt_dict (~fairseq.data.Dictionary): target vocabulary. + ids (List[str]): A list of utterance IDs. + speakers (List[str]): A list of speakers corresponding to utterances. 
+ num_mel_bins (int): Number of triangular mel-frequency bins (default: 80) + frame_length (float): Frame length in milliseconds (default: 25.0) + frame_shift (float): Frame shift in milliseconds (default: 10.0) + """ + + def __init__( + self, + aud_paths, + aud_durations_ms, + tgt, + tgt_dict, + ids, + speakers, + num_mel_bins=80, + frame_length=25.0, + frame_shift=10.0, + ): + assert frame_length > 0 + assert frame_shift > 0 + assert all(x > frame_length for x in aud_durations_ms) + self.frame_sizes = [ + int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms + ] + + assert len(aud_paths) > 0 + assert len(aud_paths) == len(aud_durations_ms) + assert len(aud_paths) == len(tgt) + assert len(aud_paths) == len(ids) + assert len(aud_paths) == len(speakers) + self.aud_paths = aud_paths + self.tgt_dict = tgt_dict + self.tgt = tgt + self.ids = ids + self.speakers = speakers + self.num_mel_bins = num_mel_bins + self.frame_length = frame_length + self.frame_shift = frame_shift + + self.s2s_collater = Seq2SeqCollater( + 0, + 1, + pad_index=self.tgt_dict.pad(), + eos_index=self.tgt_dict.eos(), + move_eos_to_beginning=True, + ) + + def __getitem__(self, index): + import torchaudio + import torchaudio.compliance.kaldi as kaldi + + tgt_item = self.tgt[index] if self.tgt is not None else None + + path = self.aud_paths[index] + if not os.path.exists(path): + raise FileNotFoundError("Audio file not found: {}".format(path)) + sound, sample_rate = torchaudio.load_wav(path) + output = kaldi.fbank( + sound, + num_mel_bins=self.num_mel_bins, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + ) + output_cmvn = data_utils.apply_mv_norm(output) + + return {"id": index, "data": [output_cmvn.detach(), tgt_item]} + + def __len__(self): + return len(self.aud_paths) + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[int]): sample indices to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + return self.s2s_collater.collate(samples) + + def num_tokens(self, index): + return self.frame_sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return ( + self.frame_sizes[index], + len(self.tgt[index]) if self.tgt is not None else 0, + ) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + return np.arange(len(self)) diff --git a/fairseq/examples/speech_recognition/data/collaters.py b/fairseq/examples/speech_recognition/data/collaters.py new file mode 100644 index 0000000000000000000000000000000000000000..6acfec876b87e5a00bc92083b1181301a2a18e3f --- /dev/null +++ b/fairseq/examples/speech_recognition/data/collaters.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" + This module contains collection of classes which implement + collate functionalities for various tasks. 
+ + Collaters should know what data to expect for each sample + and they should pack / collate them into batches +""" + + +from __future__ import absolute_import, division, print_function, unicode_literals + +import numpy as np +import torch +from fairseq.data import data_utils as fairseq_data_utils + + +class Seq2SeqCollater(object): + """ + Implements collate function mainly for seq2seq tasks + This expects each sample to contain feature (src_tokens) and + targets. + This collator is also used for aligned training task. + """ + + def __init__( + self, + feature_index=0, + label_index=1, + pad_index=1, + eos_index=2, + move_eos_to_beginning=True, + ): + self.feature_index = feature_index + self.label_index = label_index + self.pad_index = pad_index + self.eos_index = eos_index + self.move_eos_to_beginning = move_eos_to_beginning + + def _collate_frames(self, frames): + """Convert a list of 2d frames into a padded 3d tensor + Args: + frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + len_max = max(frame.size(0) for frame in frames) + f_dim = frames[0].size(1) + res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0) + + for i, v in enumerate(frames): + res[i, : v.size(0)] = v + + return res + + def collate(self, samples): + """ + utility function to collate samples into batch for speech recognition. + """ + if len(samples) == 0: + return {} + + # parse samples into torch tensors + parsed_samples = [] + for s in samples: + # skip invalid samples + if s["data"][self.feature_index] is None: + continue + source = s["data"][self.feature_index] + if isinstance(source, (np.ndarray, np.generic)): + source = torch.from_numpy(source) + target = s["data"][self.label_index] + if isinstance(target, (np.ndarray, np.generic)): + target = torch.from_numpy(target).long() + elif isinstance(target, list): + target = torch.LongTensor(target) + + parsed_sample = {"id": s["id"], "source": source, "target": target} + parsed_samples.append(parsed_sample) + samples = parsed_samples + + id = torch.LongTensor([s["id"] for s in samples]) + frames = self._collate_frames([s["source"] for s in samples]) + # sort samples by descending number of frames + frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples]) + frames_lengths, sort_order = frames_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + frames = frames.index_select(0, sort_order) + + target = None + target_lengths = None + prev_output_tokens = None + if samples[0].get("target", None) is not None: + ntokens = sum(len(s["target"]) for s in samples) + target = fairseq_data_utils.collate_tokens( + [s["target"] for s in samples], + self.pad_index, + self.eos_index, + left_pad=False, + move_eos_to_beginning=False, + ) + target = target.index_select(0, sort_order) + target_lengths = torch.LongTensor( + [s["target"].size(0) for s in samples] + ).index_select(0, sort_order) + prev_output_tokens = fairseq_data_utils.collate_tokens( + [s["target"] for s in samples], + self.pad_index, + self.eos_index, + left_pad=False, + move_eos_to_beginning=self.move_eos_to_beginning, + ) + prev_output_tokens = prev_output_tokens.index_select(0, sort_order) + else: + ntokens = sum(len(s["source"]) for s in samples) + + batch = { + "id": id, + "ntokens": ntokens, + "net_input": {"src_tokens": frames, "src_lengths": frames_lengths}, + "target": target, + "target_lengths": 
target_lengths, + "nsentences": len(samples), + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] = prev_output_tokens + return batch diff --git a/fairseq/examples/speech_recognition/data/data_utils.py b/fairseq/examples/speech_recognition/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4729e63c8ef551b29617d1169a44c24f509ad0 --- /dev/null +++ b/fairseq/examples/speech_recognition/data/data_utils.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def calc_mean_invstddev(feature): + if len(feature.size()) != 2: + raise ValueError("We expect the input feature to be 2-D tensor") + mean = feature.mean(0) + var = feature.var(0) + # avoid division by ~zero + eps = 1e-8 + if (var < eps).any(): + return mean, 1.0 / (torch.sqrt(var) + eps) + return mean, 1.0 / torch.sqrt(var) + + +def apply_mv_norm(features): + # If there is less than 2 spectrograms, the variance cannot be computed (is NaN) + # and normalization is not possible, so return the item as it is + if features.size(0) < 2: + return features + mean, invstddev = calc_mean_invstddev(features) + res = (features - mean) * invstddev + return res + + +def lengths_to_encoder_padding_mask(lengths, batch_first=False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = 0 for t < lengths[b] and 1 otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) >= lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +def encoder_padding_mask_to_lengths( + encoder_padding_mask, max_lengths, batch_size, device +): + """ + convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor + + Conventionally, encoder output contains a encoder_padding_mask, which is + a 2-D mask in a shape (T, B), whose (t, b) element indicate whether + encoder_out[t, b] is a valid output (=0) or not (=1). 
Occasionally, we + need to convert this mask tensor to a 1-D tensor in shape (B, ), where + [b] denotes the valid length of b-th sequence + + Args: + encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None, + indicating all are valid + Return: + seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the + number of valid elements of b-th sequence + + max_lengths: maximum length of all sequence, if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(0) + + batch_size: batch size; if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(1) + + device: which device to put the result on + """ + if encoder_padding_mask is None: + return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device) + + assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match" + assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match" + + return max_lengths - torch.sum(encoder_padding_mask, dim=0) diff --git a/fairseq/examples/speech_recognition/data/replabels.py b/fairseq/examples/speech_recognition/data/replabels.py new file mode 100644 index 0000000000000000000000000000000000000000..441f1bd432b95865fc981c6c695cee299b07ed62 --- /dev/null +++ b/fairseq/examples/speech_recognition/data/replabels.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Replabel transforms for use with flashlight's ASG criterion. +""" + + +def replabel_symbol(i): + """ + Replabel symbols used in flashlight, currently just "1", "2", ... + This prevents training with numeral tokens, so this might change in the future + """ + return str(i) + + +def pack_replabels(tokens, dictionary, max_reps): + """ + Pack a token sequence so that repeated symbols are replaced by replabels + """ + if len(tokens) == 0 or max_reps <= 0: + return tokens + + replabel_value_to_idx = [0] * (max_reps + 1) + for i in range(1, max_reps + 1): + replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i)) + + result = [] + prev_token = -1 + num_reps = 0 + for token in tokens: + if token == prev_token and num_reps < max_reps: + num_reps += 1 + else: + if num_reps > 0: + result.append(replabel_value_to_idx[num_reps]) + num_reps = 0 + result.append(token) + prev_token = token + if num_reps > 0: + result.append(replabel_value_to_idx[num_reps]) + return result + + +def unpack_replabels(tokens, dictionary, max_reps): + """ + Unpack a token sequence so that replabels are replaced by repeated symbols + """ + if len(tokens) == 0 or max_reps <= 0: + return tokens + + replabel_idx_to_value = {} + for i in range(1, max_reps + 1): + replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i + + result = [] + prev_token = -1 + for token in tokens: + try: + for _ in range(replabel_idx_to_value[token]): + result.append(prev_token) + prev_token = -1 + except KeyError: + result.append(token) + prev_token = token + return result diff --git a/fairseq/examples/speech_recognition/datasets/asr_prep_json.py b/fairseq/examples/speech_recognition/datasets/asr_prep_json.py new file mode 100644 index 0000000000000000000000000000000000000000..b8db8ff16691158fae034a8ab3faad622b351caf --- /dev/null +++ b/fairseq/examples/speech_recognition/datasets/asr_prep_json.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import concurrent.futures +import json +import multiprocessing +import os +from collections import namedtuple +from itertools import chain + +import sentencepiece as spm +from fairseq.data import Dictionary + + +MILLISECONDS_TO_SECONDS = 0.001 + + +def process_sample(aud_path, lable, utt_id, sp, tgt_dict): + import torchaudio + + input = {} + output = {} + si, ei = torchaudio.info(aud_path) + input["length_ms"] = int( + si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS + ) + input["path"] = aud_path + + token = " ".join(sp.EncodeAsPieces(lable)) + ids = tgt_dict.encode_line(token, append_eos=False) + output["text"] = lable + output["token"] = token + output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids])) + return {utt_id: {"input": input, "output": output}} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--audio-dirs", + nargs="+", + default=["-"], + required=True, + help="input directories with audio files", + ) + parser.add_argument( + "--labels", + required=True, + help="aggregated input labels with format per line", + type=argparse.FileType("r", encoding="UTF-8"), + ) + parser.add_argument( + "--spm-model", + required=True, + help="sentencepiece model to use for encoding", + type=argparse.FileType("r", encoding="UTF-8"), + ) + parser.add_argument( + "--dictionary", + required=True, + help="file to load fairseq dictionary from", + type=argparse.FileType("r", encoding="UTF-8"), + ) + parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav") + parser.add_argument( + "--output", + required=True, + type=argparse.FileType("w"), + help="path to save json output", + ) + args = parser.parse_args() + + sp = spm.SentencePieceProcessor() + sp.Load(args.spm_model.name) + + tgt_dict = Dictionary.load(args.dictionary) + + labels = {} + for line in args.labels: + (utt_id, label) = line.split(" ", 1) + labels[utt_id] = label + if len(labels) == 0: + raise Exception("No labels found in ", args.labels_path) + + Sample = namedtuple("Sample", "aud_path utt_id") + samples = [] + for path, _, files in chain.from_iterable( + os.walk(path) for path in args.audio_dirs + ): + for f in files: + if f.endswith(args.audio_format): + if len(os.path.splitext(f)) != 2: + raise Exception("Expect file name. 
Got: ", f) + utt_id = os.path.splitext(f)[0] + if utt_id not in labels: + continue + samples.append(Sample(os.path.join(path, f), utt_id)) + + utts = {} + num_cpu = multiprocessing.cpu_count() + with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor: + future_to_sample = { + executor.submit( + process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict + ): s + for s in samples + } + for future in concurrent.futures.as_completed(future_to_sample): + try: + data = future.result() + except Exception as exc: + print("generated an exception: ", exc) + else: + utts.update(data) + json.dump({"utts": utts}, args.output, indent=4) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh b/fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh new file mode 100644 index 0000000000000000000000000000000000000000..9e9297f08947027685ff508bfa91ff26b0d8ea0c --- /dev/null +++ b/fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Prepare librispeech dataset + +base_url=www.openslr.org/resources/12 +train_dir=train_960 + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final" + exit 1 +fi + +download_dir=${1%/} +out_dir=${2%/} + +fairseq_root=~/fairseq-py/ +mkdir -p ${out_dir} +cd ${out_dir} || exit + +nbpe=5000 +bpemode=unigram + +if [ ! -d "$fairseq_root" ]; then + echo "$0: Please set correct fairseq_root" + exit 1 +fi + +echo "Data Download" +for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do + url=$base_url/$part.tar.gz + if ! wget -P $download_dir $url; then + echo "$0: wget failed for $url" + exit 1 + fi + if ! 
tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then + echo "$0: error un-tarring archive $download_dir/$part.tar.gz" + exit 1 + fi +done + +echo "Merge all train packs into one" +mkdir -p ${download_dir}/LibriSpeech/${train_dir}/ +for part in train-clean-100 train-clean-360 train-other-500; do + mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/ +done +echo "Merge train text" +find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text + +# Use combined dev-clean and dev-other as validation set +find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text +find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text +find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text + + +dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt +encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt +fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt +bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe} +echo "dictionary: ${dict}" +echo "Dictionary preparation" +mkdir -p data/lang_char/ +echo " 3" > ${dict} +echo " 2" >> ${dict} +echo " 1" >> ${dict} +cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt +spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1 +spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded} +cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict} +cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict} +wc -l ${dict} + +echo "Prepare train and test jsons" +for part in train_960 test-other test-clean; do + python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json +done +# fairseq expects to find train.json and valid.json during training +mv train_960.json train.json + +echo "Prepare valid json" +python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json + +cp ${fairseq_dict} ./dict.txt +cp ${bpemodel}.model ./spm.model diff --git a/fairseq/examples/speech_recognition/infer.py b/fairseq/examples/speech_recognition/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..ce16bf47cfb63c9de787c36d4a5edd24cde6e421 --- /dev/null +++ b/fairseq/examples/speech_recognition/infer.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run inference for pre-processed data with a trained model. 
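+
+When --results-path is set, hypotheses and references are written there as
+hypo.word-*, hypo.units-*, ref.word-* and ref.units-* files, and the word error
+rate is logged at the end of the run (unless --dump-emissions or
+--dump-features is used, in which case only the dumped arrays are saved).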
+""" + +import ast +import logging +import math +import os +import sys + +import editdistance +import numpy as np +import torch +from fairseq import checkpoint_utils, options, progress_bar, tasks, utils +from fairseq.data.data_utils import post_process +from fairseq.logging.meters import StopwatchMeter, TimeMeter + + +logging.basicConfig() +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def add_asr_eval_argument(parser): + parser.add_argument("--kspmodel", default=None, help="sentence piece model") + parser.add_argument( + "--wfstlm", default=None, help="wfstlm on dictonary output units" + ) + parser.add_argument( + "--rnnt_decoding_type", + default="greedy", + help="wfstlm on dictonary\ +output units", + ) + try: + parser.add_argument( + "--lm-weight", + "--lm_weight", + type=float, + default=0.2, + help="weight for lm while interpolating with neural score", + ) + except: + pass + parser.add_argument( + "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level" + ) + parser.add_argument( + "--w2l-decoder", + choices=["viterbi", "kenlm", "fairseqlm"], + help="use a w2l decoder", + ) + parser.add_argument("--lexicon", help="lexicon for w2l decoder") + parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm") + parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder") + parser.add_argument("--beam-threshold", type=float, default=25.0) + parser.add_argument("--beam-size-token", type=float, default=100) + parser.add_argument("--word-score", type=float, default=1.0) + parser.add_argument("--unk-weight", type=float, default=-math.inf) + parser.add_argument("--sil-weight", type=float, default=0.0) + parser.add_argument( + "--dump-emissions", + type=str, + default=None, + help="if present, dumps emissions into this file and exits", + ) + parser.add_argument( + "--dump-features", + type=str, + default=None, + help="if present, dumps features into this file and exits", + ) + parser.add_argument( + "--load-emissions", + type=str, + default=None, + help="if present, loads emissions from this file", + ) + return parser + + +def check_args(args): + # assert args.path is not None, "--path required for generation!" + # assert args.results_path is not None, "--results_path required for generation!" 
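+    # The two checks above stay disabled: result files are only written when
+    # --results-path is set, and a checkpoint path is not required when
+    # --load-emissions is used.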
+ assert ( + not args.sampling or args.nbest == args.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + args.replace_unk is None or args.raw_text + ), "--replace-unk requires a raw text dataset (--raw-text)" + + +def get_dataset_itr(args, task, models): + return task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_sentences=args.batch_size, + max_positions=(sys.maxsize, sys.maxsize), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + num_shards=args.num_shards, + shard_id=args.shard_id, + num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, + ).next_epoch_itr(shuffle=False) + + +def process_predictions( + args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id +): + for hypo in hypos[: min(len(hypos), args.nbest)]: + hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu()) + + if "words" in hypo: + hyp_words = " ".join(hypo["words"]) + else: + hyp_words = post_process(hyp_pieces, args.post_process) + + if res_files is not None: + print( + "{} ({}-{})".format(hyp_pieces, speaker, id), + file=res_files["hypo.units"], + ) + print( + "{} ({}-{})".format(hyp_words, speaker, id), + file=res_files["hypo.words"], + ) + + tgt_pieces = tgt_dict.string(target_tokens) + tgt_words = post_process(tgt_pieces, args.post_process) + + if res_files is not None: + print( + "{} ({}-{})".format(tgt_pieces, speaker, id), + file=res_files["ref.units"], + ) + print( + "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"] + ) + + if not args.quiet: + logger.info("HYPO:" + hyp_words) + logger.info("TARGET:" + tgt_words) + logger.info("___________________") + + hyp_words = hyp_words.split() + tgt_words = tgt_words.split() + return editdistance.eval(hyp_words, tgt_words), len(tgt_words) + + +def prepare_result_files(args): + def get_res_file(file_prefix): + if args.num_shards > 1: + file_prefix = f"{args.shard_id}_{file_prefix}" + path = os.path.join( + args.results_path, + "{}-{}-{}.txt".format( + file_prefix, os.path.basename(args.path), args.gen_subset + ), + ) + return open(path, "w", buffering=1) + + if not args.results_path: + return None + + return { + "hypo.words": get_res_file("hypo.word"), + "hypo.units": get_res_file("hypo.units"), + "ref.words": get_res_file("ref.word"), + "ref.units": get_res_file("ref.units"), + } + + +def optimize_models(args, use_cuda, models): + """Optimize ensemble for generation""" + for model in models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if args.fp16: + model.half() + if use_cuda: + model.cuda() + + +def apply_half(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.half) + return t + + +class ExistingEmissionsDecoder(object): + def __init__(self, decoder, emissions): + self.decoder = decoder + self.emissions = emissions + + def generate(self, models, sample, **unused): + ids = sample["id"].cpu().numpy() + try: + emissions = np.stack(self.emissions[ids]) + except: + print([x.shape for x in self.emissions[ids]]) + raise Exception("invalid sizes") + emissions = torch.from_numpy(emissions) + return self.decoder.decode(emissions) + + +def main(args, task=None, model_state=None): + check_args(args) + + use_fp16 = args.fp16 + if args.max_tokens is None and args.batch_size is None: + args.max_tokens = 4000000 + logger.info(args) + + use_cuda = torch.cuda.is_available() and not 
args.cpu + + logger.info("| decoding with criterion {}".format(args.criterion)) + + task = tasks.setup_task(args) + + # Load ensemble + if args.load_emissions: + models, criterions = [], [] + task.load_dataset(args.gen_subset) + else: + logger.info("| loading model(s) from {}".format(args.path)) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + utils.split_paths(args.path, separator="\\"), + arg_overrides=ast.literal_eval(args.model_overrides), + task=task, + suffix=args.checkpoint_suffix, + strict=(args.checkpoint_shard_count == 1), + num_shards=args.checkpoint_shard_count, + state=model_state, + ) + optimize_models(args, use_cuda, models) + task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task) + + + # Set dictionary + tgt_dict = task.target_dictionary + + logger.info( + "| {} {} {} examples".format( + args.data, args.gen_subset, len(task.dataset(args.gen_subset)) + ) + ) + + # hack to pass transitions to W2lDecoder + if args.criterion == "asg_loss": + raise NotImplementedError("asg_loss is currently not supported") + # trans = criterions[0].asg.trans.data + # args.asg_transitions = torch.flatten(trans).tolist() + + # Load dataset (possibly sharded) + itr = get_dataset_itr(args, task, models) + + # Initialize generator + gen_timer = StopwatchMeter() + + def build_generator(args): + w2l_decoder = getattr(args, "w2l_decoder", None) + if w2l_decoder == "viterbi": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, task.target_dictionary) + elif w2l_decoder == "kenlm": + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + return W2lKenLMDecoder(args, task.target_dictionary) + elif w2l_decoder == "fairseqlm": + from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder + + return W2lFairseqLMDecoder(args, task.target_dictionary) + else: + print( + "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment" + ) + + # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task + generator = build_generator(args) + + if args.load_emissions: + generator = ExistingEmissionsDecoder( + generator, np.load(args.load_emissions, allow_pickle=True) + ) + logger.info("loaded emissions from " + args.load_emissions) + + num_sentences = 0 + + if args.results_path is not None and not os.path.exists(args.results_path): + os.makedirs(args.results_path) + + max_source_pos = ( + utils.resolve_max_positions( + task.max_positions(), *[model.max_positions() for model in models] + ), + ) + + if max_source_pos is not None: + max_source_pos = max_source_pos[0] + if max_source_pos is not None: + max_source_pos = max_source_pos[0] - 1 + + if args.dump_emissions: + emissions = {} + if args.dump_features: + features = {} + models[0].bert.proj = None + else: + res_files = prepare_result_files(args) + errs_t = 0 + lengths_t = 0 + with progress_bar.build_progress_bar(args, itr) as t: + wps_meter = TimeMeter() + for sample in t: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if use_fp16: + sample = utils.apply_to_sample(apply_half, sample) + if "net_input" not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample["target"][:, : args.prefix_size] + + gen_timer.start() + if args.dump_emissions: + with torch.no_grad(): + encoder_out = models[0](**sample["net_input"]) + emm = models[0].get_normalized_probs(encoder_out, log_probs=True) + emm = emm.transpose(0, 1).cpu().numpy() + 
for i, id in enumerate(sample["id"]): + emissions[id.item()] = emm[i] + continue + elif args.dump_features: + with torch.no_grad(): + encoder_out = models[0](**sample["net_input"]) + feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy() + for i, id in enumerate(sample["id"]): + padding = ( + encoder_out["encoder_padding_mask"][i].cpu().numpy() + if encoder_out["encoder_padding_mask"] is not None + else None + ) + features[id.item()] = (feat[i], padding) + continue + hypos = task.inference_step(generator, models, sample, prefix_tokens) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i, sample_id in enumerate(sample["id"].tolist()): + speaker = None + # id = task.dataset(args.gen_subset).ids[int(sample_id)] + id = sample_id + toks = ( + sample["target"][i, :] + if "target_label" not in sample + else sample["target_label"][i, :] + ) + target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu() + # Process top predictions + errs, length = process_predictions( + args, + hypos[i], + None, + tgt_dict, + target_tokens, + res_files, + speaker, + id, + ) + errs_t += errs + lengths_t += length + + wps_meter.update(num_generated_tokens) + t.log({"wps": round(wps_meter.avg)}) + num_sentences += ( + sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + ) + + wer = None + if args.dump_emissions: + emm_arr = [] + for i in range(len(emissions)): + emm_arr.append(emissions[i]) + np.save(args.dump_emissions, emm_arr) + logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}") + elif args.dump_features: + feat_arr = [] + for i in range(len(features)): + feat_arr.append(features[i]) + np.save(args.dump_features, feat_arr) + logger.info(f"saved {len(features)} emissions to {args.dump_features}") + else: + if lengths_t > 0: + wer = errs_t * 100.0 / lengths_t + logger.info(f"WER: {wer}") + + logger.info( + "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}" + "sentences/s, {:.2f} tokens/s)".format( + num_sentences, + gen_timer.n, + gen_timer.sum, + num_sentences / gen_timer.sum, + 1.0 / gen_timer.avg, + ) + ) + logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam)) + return task, wer + + +def make_parser(): + parser = options.get_generation_parser() + parser = add_asr_eval_argument(parser) + return parser + + +def cli_main(): + parser = make_parser() + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/speech_recognition/kaldi/__init__.py b/fairseq/examples/speech_recognition/kaldi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc b/fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc new file mode 100644 index 0000000000000000000000000000000000000000..e18fb62df52ab85d7802615d8619b0fd94a08f8c --- /dev/null +++ b/fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc @@ -0,0 +1,94 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include "fstext/fstext-lib.h" // @manual +#include "util/common-utils.h" // @manual + +/* + * This program is to modify a FST without self-loop by: + * for each incoming arc with non-eps input symbol, add a self-loop arc + * with that non-eps symbol as input and eps as output. + * + * This is to make sure the resultant FST can do deduplication for repeated + * symbols, which is very common in acoustic model + * + */ +namespace { +int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) { + typedef fst::MutableArcIterator IterType; + + int32 num_states_before = fst->NumStates(); + fst::MakePrecedingInputSymbolsSame(false, fst); + int32 num_states_after = fst->NumStates(); + KALDI_LOG << "There are " << num_states_before + << " states in the original FST; " + << " after MakePrecedingInputSymbolsSame, there are " + << num_states_after << " states " << std::endl; + + auto weight_one = fst::StdArc::Weight::One(); + + int32 num_arc_added = 0; + + fst::StdArc self_loop_arc; + self_loop_arc.weight = weight_one; + + int32 num_states = fst->NumStates(); + std::vector> incoming_non_eps_label_per_state(num_states); + + for (int32 state = 0; state < num_states; state++) { + for (IterType aiter(fst, state); !aiter.Done(); aiter.Next()) { + fst::StdArc arc(aiter.Value()); + if (arc.ilabel != 0) { + incoming_non_eps_label_per_state[arc.nextstate].insert(arc.ilabel); + } + } + } + + for (int32 state = 0; state < num_states; state++) { + if (!incoming_non_eps_label_per_state[state].empty()) { + auto& ilabel_set = incoming_non_eps_label_per_state[state]; + for (auto it = ilabel_set.begin(); it != ilabel_set.end(); it++) { + self_loop_arc.ilabel = *it; + self_loop_arc.olabel = 0; + self_loop_arc.nextstate = state; + fst->AddArc(state, self_loop_arc); + num_arc_added++; + } + } + } + return num_arc_added; +} + +void print_usage() { + std::cout << "add-self-loop-simple usage:\n" + "\tadd-self-loop-simple \n"; +} +} // namespace + +int main(int argc, char** argv) { + if (argc != 3) { + print_usage(); + exit(1); + } + + auto input = argv[1]; + auto output = argv[2]; + + auto fst = fst::ReadFstKaldi(input); + auto num_states = fst->NumStates(); + KALDI_LOG << "Loading FST from " << input << " with " << num_states + << " states." << std::endl; + + int32 num_arc_added = AddSelfLoopsSimple(fst); + KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl; + + fst::WriteFstKaldi(*fst, std::string(output)); + KALDI_LOG << "Writing FST to " << output << std::endl; + + delete fst; +} diff --git a/fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml b/fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be9ba98f55463d41d5d5ea35e306abc0886dbead --- /dev/null +++ b/fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml @@ -0,0 +1,8 @@ +# @package _group_ + +data_dir: ??? +fst_dir: ??? +in_labels: ??? +kaldi_root: ??? +lm_arpa: ??? +blank_symbol: diff --git a/fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py b/fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5f62cc58ae8c0c5a3ba7d17713fedf0abc302942 --- /dev/null +++ b/fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from concurrent.futures import ThreadPoolExecutor +import logging +from omegaconf import MISSING +import os +import torch +from typing import Optional +import warnings + + +from dataclasses import dataclass +from fairseq.dataclass import FairseqDataclass +from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi + + +logger = logging.getLogger(__name__) + + +@dataclass +class KaldiDecoderConfig(FairseqDataclass): + hlg_graph_path: Optional[str] = None + output_dict: str = MISSING + + kaldi_initializer_config: Optional[KaldiInitializerConfig] = None + + acoustic_scale: float = 0.5 + max_active: int = 10000 + beam_delta: float = 0.5 + hash_ratio: float = 2.0 + + is_lattice: bool = False + lattice_beam: float = 10.0 + prune_interval: int = 25 + determinize_lattice: bool = True + prune_scale: float = 0.1 + max_mem: int = 0 + phone_determinize: bool = True + word_determinize: bool = True + minimize: bool = True + + num_threads: int = 1 + + +class KaldiDecoder(object): + def __init__( + self, + cfg: KaldiDecoderConfig, + beam: int, + nbest: int = 1, + ): + try: + from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer + from kaldi.base import set_verbose_level + from kaldi.decoder import ( + FasterDecoder, + FasterDecoderOptions, + LatticeFasterDecoder, + LatticeFasterDecoderOptions, + ) + from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions + from kaldi.fstext import read_fst_kaldi, SymbolTable + except: + warnings.warn( + "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi" + ) + + # set_verbose_level(2) + + self.acoustic_scale = cfg.acoustic_scale + self.nbest = nbest + + if cfg.hlg_graph_path is None: + assert ( + cfg.kaldi_initializer_config is not None + ), "Must provide hlg graph path or kaldi initializer config" + cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config) + + assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path + + if cfg.is_lattice: + self.dec_cls = LatticeFasterDecoder + opt_cls = LatticeFasterDecoderOptions + self.rec_cls = LatticeFasterRecognizer + else: + assert self.nbest == 1, "nbest > 1 requires lattice decoder" + self.dec_cls = FasterDecoder + opt_cls = FasterDecoderOptions + self.rec_cls = FasterRecognizer + + self.decoder_options = opt_cls() + self.decoder_options.beam = beam + self.decoder_options.max_active = cfg.max_active + self.decoder_options.beam_delta = cfg.beam_delta + self.decoder_options.hash_ratio = cfg.hash_ratio + + if cfg.is_lattice: + self.decoder_options.lattice_beam = cfg.lattice_beam + self.decoder_options.prune_interval = cfg.prune_interval + self.decoder_options.determinize_lattice = cfg.determinize_lattice + self.decoder_options.prune_scale = cfg.prune_scale + det_opts = DeterminizeLatticePhonePrunedOptions() + det_opts.max_mem = cfg.max_mem + det_opts.phone_determinize = cfg.phone_determinize + det_opts.word_determinize = cfg.word_determinize + det_opts.minimize = cfg.minimize + self.decoder_options.det_opts = det_opts + + self.output_symbols = {} + with open(cfg.output_dict, "r") as f: + for line in f: + items = line.rstrip().split() + assert len(items) == 2 + self.output_symbols[int(items[1])] = items[0] + + logger.info(f"Loading FST from {cfg.hlg_graph_path}") + self.fst = read_fst_kaldi(cfg.hlg_graph_path) + self.symbol_table = SymbolTable.read_text(cfg.output_dict) + + self.executor = 
ThreadPoolExecutor(max_workers=cfg.num_threads) + + def generate(self, models, sample, **unused): + """Generate a batch of inferences.""" + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" + } + emissions, padding = self.get_emissions(models, encoder_input) + return self.decode(emissions, padding) + + def get_emissions(self, models, encoder_input): + """Run encoder and normalize emissions""" + model = models[0] + + all_encoder_out = [m(**encoder_input) for m in models] + + if len(all_encoder_out) > 1: + + if "encoder_out" in all_encoder_out[0]: + encoder_out = { + "encoder_out": sum(e["encoder_out"] for e in all_encoder_out) + / len(all_encoder_out), + "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"], + } + padding = encoder_out["encoder_padding_mask"] + else: + encoder_out = { + "logits": sum(e["logits"] for e in all_encoder_out) + / len(all_encoder_out), + "padding_mask": all_encoder_out[0]["padding_mask"], + } + padding = encoder_out["padding_mask"] + else: + encoder_out = all_encoder_out[0] + padding = ( + encoder_out["padding_mask"] + if "padding_mask" in encoder_out + else encoder_out["encoder_padding_mask"] + ) + + if hasattr(model, "get_logits"): + emissions = model.get_logits(encoder_out, normalize=True) + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) + + return ( + emissions.cpu().float().transpose(0, 1), + padding.cpu() if padding is not None and padding.any() else None, + ) + + def decode_one(self, logits, padding): + from kaldi.matrix import Matrix + + decoder = self.dec_cls(self.fst, self.decoder_options) + asr = self.rec_cls( + decoder, self.symbol_table, acoustic_scale=self.acoustic_scale + ) + + if padding is not None: + logits = logits[~padding] + + mat = Matrix(logits.numpy()) + + out = asr.decode(mat) + + if self.nbest > 1: + from kaldi.fstext import shortestpath + from kaldi.fstext.utils import ( + convert_compact_lattice_to_lattice, + convert_lattice_to_std, + convert_nbest_to_list, + get_linear_symbol_sequence, + ) + + lat = out["lattice"] + + sp = shortestpath(lat, nshortest=self.nbest) + + sp = convert_compact_lattice_to_lattice(sp) + sp = convert_lattice_to_std(sp) + seq = convert_nbest_to_list(sp) + + results = [] + for s in seq: + _, o, w = get_linear_symbol_sequence(s) + words = list(self.output_symbols[z] for z in o) + results.append( + { + "tokens": words, + "words": words, + "score": w.value, + "emissions": logits, + } + ) + return results + else: + words = out["text"].split() + return [ + { + "tokens": words, + "words": words, + "score": out["likelihood"], + "emissions": logits, + } + ] + + def decode(self, emissions, padding): + if padding is None: + padding = [None] * len(emissions) + + ret = list( + map( + lambda e, p: self.executor.submit(self.decode_one, e, p), + emissions, + padding, + ) + ) + return ret diff --git a/fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py b/fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2a2a4b6b809ba1106f9a57cb6f241dc083e670 --- /dev/null +++ b/fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py @@ -0,0 +1,698 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +import hydra +from hydra.core.config_store import ConfigStore +import logging +from omegaconf import MISSING, OmegaConf +import os +import os.path as osp +from pathlib import Path +import subprocess +from typing import Optional + +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import FairseqDataclass + +script_dir = Path(__file__).resolve().parent +config_path = script_dir / "config" + + +logger = logging.getLogger(__name__) + + +@dataclass +class KaldiInitializerConfig(FairseqDataclass): + data_dir: str = MISSING + fst_dir: Optional[str] = None + in_labels: str = MISSING + out_labels: Optional[str] = None + wav2letter_lexicon: Optional[str] = None + lm_arpa: str = MISSING + kaldi_root: str = MISSING + blank_symbol: str = "" + silence_symbol: Optional[str] = None + + +def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path: + in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt" + if not in_units_file.exists(): + + logger.info(f"Creating {in_units_file}") + + with open(in_units_file, "w") as f: + print(" 0", file=f) + i = 1 + for symb in vocab.symbols[vocab.nspecial :]: + if not symb.startswith("madeupword"): + print(f"{symb} {i}", file=f) + i += 1 + return in_units_file + + +def create_lexicon( + cfg: KaldiInitializerConfig, + fst_dir: Path, + unique_label: str, + in_units_file: Path, + out_words_file: Path, +) -> (Path, Path): + + disambig_in_units_file = fst_dir / f"kaldi_dict.{cfg.in_labels}_disambig.txt" + lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}.txt" + disambig_lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}_disambig.txt" + if ( + not lexicon_file.exists() + or not disambig_lexicon_file.exists() + or not disambig_in_units_file.exists() + ): + logger.info(f"Creating {lexicon_file} (in units file: {in_units_file})") + + assert cfg.wav2letter_lexicon is not None or cfg.in_labels == cfg.out_labels + + if cfg.wav2letter_lexicon is not None: + lm_words = set() + with open(out_words_file, "r") as lm_dict_f: + for line in lm_dict_f: + lm_words.add(line.split()[0]) + + num_skipped = 0 + total = 0 + with open(cfg.wav2letter_lexicon, "r") as w2l_lex_f, open( + lexicon_file, "w" + ) as out_f: + for line in w2l_lex_f: + items = line.rstrip().split("\t") + assert len(items) == 2, items + if items[0] in lm_words: + print(items[0], items[1], file=out_f) + else: + num_skipped += 1 + logger.debug( + f"Skipping word {items[0]} as it was not found in LM" + ) + total += 1 + if num_skipped > 0: + logger.warning( + f"Skipped {num_skipped} out of {total} words as they were not found in LM" + ) + else: + with open(in_units_file, "r") as in_f, open(lexicon_file, "w") as out_f: + for line in in_f: + symb = line.split()[0] + if symb != "" and symb != "" and symb != "": + print(symb, symb, file=out_f) + + lex_disambig_path = ( + Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_lex_disambig.pl" + ) + res = subprocess.run( + [lex_disambig_path, lexicon_file, disambig_lexicon_file], + check=True, + capture_output=True, + ) + ndisambig = int(res.stdout) + disamib_path = Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_disambig.pl" + res = subprocess.run( + [disamib_path, "--include-zero", in_units_file, str(ndisambig)], + check=True, + capture_output=True, + ) + with open(disambig_in_units_file, "wb") as f: + f.write(res.stdout) + + return disambig_lexicon_file, disambig_in_units_file + + +def 
create_G( + kaldi_root: Path, fst_dir: Path, lm_arpa: Path, arpa_base: str +) -> (Path, Path): + + out_words_file = fst_dir / f"kaldi_dict.{arpa_base}.txt" + grammar_graph = fst_dir / f"G_{arpa_base}.fst" + if not grammar_graph.exists() or not out_words_file.exists(): + logger.info(f"Creating {grammar_graph}") + arpa2fst = kaldi_root / "src/lmbin/arpa2fst" + subprocess.run( + [ + arpa2fst, + "--disambig-symbol=#0", + f"--write-symbol-table={out_words_file}", + lm_arpa, + grammar_graph, + ], + check=True, + ) + return grammar_graph, out_words_file + + +def create_L( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + lexicon_file: Path, + in_units_file: Path, + out_words_file: Path, +) -> Path: + lexicon_graph = fst_dir / f"L.{unique_label}.fst" + + if not lexicon_graph.exists(): + logger.info(f"Creating {lexicon_graph} (in units: {in_units_file})") + make_lex = kaldi_root / "egs/wsj/s5/utils/make_lexicon_fst.pl" + fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile" + fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops" + fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" + + def write_disambig_symbol(file): + with open(file, "r") as f: + for line in f: + items = line.rstrip().split() + if items[0] == "#0": + out_path = str(file) + "_disamig" + with open(out_path, "w") as out_f: + print(items[1], file=out_f) + return out_path + + return None + + in_disambig_sym = write_disambig_symbol(in_units_file) + assert in_disambig_sym is not None + out_disambig_sym = write_disambig_symbol(out_words_file) + assert out_disambig_sym is not None + + try: + with open(lexicon_graph, "wb") as out_f: + res = subprocess.run( + [make_lex, lexicon_file], capture_output=True, check=True + ) + assert len(res.stderr) == 0, res.stderr.decode("utf-8") + res = subprocess.run( + [ + fstcompile, + f"--isymbols={in_units_file}", + f"--osymbols={out_words_file}", + "--keep_isymbols=false", + "--keep_osymbols=false", + ], + input=res.stdout, + capture_output=True, + ) + assert len(res.stderr) == 0, res.stderr.decode("utf-8") + res = subprocess.run( + [fstaddselfloops, in_disambig_sym, out_disambig_sym], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstarcsort, "--sort_type=olabel"], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(lexicon_graph) + raise + except AssertionError: + os.remove(lexicon_graph) + raise + + return lexicon_graph + + +def create_LG( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + lexicon_graph: Path, + grammar_graph: Path, +) -> Path: + lg_graph = fst_dir / f"LG.{unique_label}.fst" + + if not lg_graph.exists(): + logger.info(f"Creating {lg_graph}") + + fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" + fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar" + fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" + fstpushspecial = kaldi_root / "src/fstbin/fstpushspecial" + fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" + + try: + with open(lg_graph, "wb") as out_f: + res = subprocess.run( + [fsttablecompose, lexicon_graph, grammar_graph], + capture_output=True, + check=True, + ) + res = subprocess.run( + [ + fstdeterminizestar, + "--use-log=true", + ], + input=res.stdout, + capture_output=True, + ) + res = subprocess.run( + [fstminimizeencoded], + input=res.stdout, + capture_output=True, + check=True, + 
) + res = subprocess.run( + [fstpushspecial], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstarcsort, "--sort_type=ilabel"], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(lg_graph) + raise + + return lg_graph + + +def create_H( + kaldi_root: Path, + fst_dir: Path, + disambig_out_units_file: Path, + in_labels: str, + vocab: Dictionary, + blk_sym: str, + silence_symbol: Optional[str], +) -> (Path, Path, Path): + h_graph = ( + fst_dir / f"H.{in_labels}{'_' + silence_symbol if silence_symbol else ''}.fst" + ) + h_out_units_file = fst_dir / f"kaldi_dict.h_out.{in_labels}.txt" + disambig_in_units_file_int = Path(str(h_graph) + "isym_disambig.int") + disambig_out_units_file_int = Path(str(disambig_out_units_file) + ".int") + if ( + not h_graph.exists() + or not h_out_units_file.exists() + or not disambig_in_units_file_int.exists() + ): + logger.info(f"Creating {h_graph}") + eps_sym = "" + + num_disambig = 0 + osymbols = [] + + with open(disambig_out_units_file, "r") as f, open( + disambig_out_units_file_int, "w" + ) as out_f: + for line in f: + symb, id = line.rstrip().split() + if line.startswith("#"): + num_disambig += 1 + print(id, file=out_f) + else: + if len(osymbols) == 0: + assert symb == eps_sym, symb + osymbols.append((symb, id)) + + i_idx = 0 + isymbols = [(eps_sym, 0)] + + imap = {} + + for i, s in enumerate(vocab.symbols): + i_idx += 1 + isymbols.append((s, i_idx)) + imap[s] = i_idx + + fst_str = [] + + node_idx = 0 + root_node = node_idx + + special_symbols = [blk_sym] + if silence_symbol is not None: + special_symbols.append(silence_symbol) + + for ss in special_symbols: + fst_str.append("{} {} {} {}".format(root_node, root_node, ss, eps_sym)) + + for symbol, _ in osymbols: + if symbol == eps_sym or symbol.startswith("#"): + continue + + node_idx += 1 + # 1. from root to emitting state + fst_str.append("{} {} {} {}".format(root_node, node_idx, symbol, symbol)) + # 2. from emitting state back to root + fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym)) + # 3. from emitting state to optional blank state + pre_node = node_idx + node_idx += 1 + for ss in special_symbols: + fst_str.append("{} {} {} {}".format(pre_node, node_idx, ss, eps_sym)) + # 4. 
from blank state back to root + fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym)) + + fst_str.append("{}".format(root_node)) + + fst_str = "\n".join(fst_str) + h_str = str(h_graph) + isym_file = h_str + ".isym" + + with open(isym_file, "w") as f: + for sym, id in isymbols: + f.write("{} {}\n".format(sym, id)) + + with open(h_out_units_file, "w") as f: + for sym, id in osymbols: + f.write("{} {}\n".format(sym, id)) + + with open(disambig_in_units_file_int, "w") as f: + disam_sym_id = len(isymbols) + for _ in range(num_disambig): + f.write("{}\n".format(disam_sym_id)) + disam_sym_id += 1 + + fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile" + fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops" + fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" + + try: + with open(h_graph, "wb") as out_f: + res = subprocess.run( + [ + fstcompile, + f"--isymbols={isym_file}", + f"--osymbols={h_out_units_file}", + "--keep_isymbols=false", + "--keep_osymbols=false", + ], + input=str.encode(fst_str), + capture_output=True, + check=True, + ) + res = subprocess.run( + [ + fstaddselfloops, + disambig_in_units_file_int, + disambig_out_units_file_int, + ], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstarcsort, "--sort_type=olabel"], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(h_graph) + raise + return h_graph, h_out_units_file, disambig_in_units_file_int + + +def create_HLGa( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + h_graph: Path, + lg_graph: Path, + disambig_in_words_file_int: Path, +) -> Path: + hlga_graph = fst_dir / f"HLGa.{unique_label}.fst" + + if not hlga_graph.exists(): + logger.info(f"Creating {hlga_graph}") + + fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" + fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar" + fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols" + fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal" + fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" + + try: + with open(hlga_graph, "wb") as out_f: + res = subprocess.run( + [ + fsttablecompose, + h_graph, + lg_graph, + ], + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstdeterminizestar, "--use-log=true"], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstrmsymbols, disambig_in_words_file_int], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstrmepslocal], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstminimizeencoded], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(hlga_graph) + raise + + return hlga_graph + + +def create_HLa( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + h_graph: Path, + l_graph: Path, + disambig_in_words_file_int: Path, +) -> Path: + hla_graph = fst_dir / f"HLa.{unique_label}.fst" + + if not hla_graph.exists(): + logger.info(f"Creating {hla_graph}") + + fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" + fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar" + fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols" + fstrmepslocal = kaldi_root / 
"src/fstbin/fstrmepslocal" + fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" + + try: + with open(hla_graph, "wb") as out_f: + res = subprocess.run( + [ + fsttablecompose, + h_graph, + l_graph, + ], + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstdeterminizestar, "--use-log=true"], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstrmsymbols, disambig_in_words_file_int], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstrmepslocal], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstminimizeencoded], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(hla_graph) + raise + + return hla_graph + + +def create_HLG( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + hlga_graph: Path, + prefix: str = "HLG", +) -> Path: + hlg_graph = fst_dir / f"{prefix}.{unique_label}.fst" + + if not hlg_graph.exists(): + logger.info(f"Creating {hlg_graph}") + + add_self_loop = script_dir / "add-self-loop-simple" + kaldi_src = kaldi_root / "src" + kaldi_lib = kaldi_src / "lib" + + try: + if not add_self_loop.exists(): + fst_include = kaldi_root / "tools/openfst-1.6.7/include" + add_self_loop_src = script_dir / "add-self-loop-simple.cc" + + subprocess.run( + [ + "c++", + f"-I{kaldi_src}", + f"-I{fst_include}", + f"-L{kaldi_lib}", + add_self_loop_src, + "-lkaldi-base", + "-lkaldi-fstext", + "-o", + add_self_loop, + ], + check=True, + ) + + my_env = os.environ.copy() + my_env["LD_LIBRARY_PATH"] = f"{kaldi_lib}:{my_env['LD_LIBRARY_PATH']}" + + subprocess.run( + [ + add_self_loop, + hlga_graph, + hlg_graph, + ], + check=True, + capture_output=True, + env=my_env, + ) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + raise + + return hlg_graph + + +def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path: + if cfg.fst_dir is None: + cfg.fst_dir = osp.join(cfg.data_dir, "kaldi") + if cfg.out_labels is None: + cfg.out_labels = cfg.in_labels + + kaldi_root = Path(cfg.kaldi_root) + data_dir = Path(cfg.data_dir) + fst_dir = Path(cfg.fst_dir) + fst_dir.mkdir(parents=True, exist_ok=True) + + arpa_base = osp.splitext(osp.basename(cfg.lm_arpa))[0] + unique_label = f"{cfg.in_labels}.{arpa_base}" + + with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as f: + vocab = Dictionary.load(f) + + in_units_file = create_units(fst_dir, cfg.in_labels, vocab) + + grammar_graph, out_words_file = create_G( + kaldi_root, fst_dir, Path(cfg.lm_arpa), arpa_base + ) + + disambig_lexicon_file, disambig_L_in_units_file = create_lexicon( + cfg, fst_dir, unique_label, in_units_file, out_words_file + ) + + h_graph, h_out_units_file, disambig_in_units_file_int = create_H( + kaldi_root, + fst_dir, + disambig_L_in_units_file, + cfg.in_labels, + vocab, + cfg.blank_symbol, + cfg.silence_symbol, + ) + lexicon_graph = create_L( + kaldi_root, + fst_dir, + unique_label, + disambig_lexicon_file, + disambig_L_in_units_file, + out_words_file, + ) + lg_graph = create_LG( + kaldi_root, fst_dir, unique_label, lexicon_graph, grammar_graph + ) + hlga_graph = create_HLGa( + kaldi_root, fst_dir, unique_label, h_graph, lg_graph, disambig_in_units_file_int + ) + hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph) + + # for debugging + # hla_graph = 
create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int) + # hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped") + # create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped") + + return hlg_graph + + +@hydra.main(config_path=config_path, config_name="kaldi_initializer") +def cli_main(cfg: KaldiInitializerConfig) -> None: + container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + cfg = OmegaConf.create(container) + OmegaConf.set_struct(cfg, True) + initalize_kaldi(cfg) + + +if __name__ == "__main__": + + logging.root.setLevel(logging.INFO) + logging.basicConfig(level=logging.INFO) + + try: + from hydra._internal.utils import ( + get_args, + ) # pylint: disable=import-outside-toplevel + + cfg_name = get_args().config_name or "kaldi_initializer" + except ImportError: + logger.warning("Failed to get config name from hydra args") + cfg_name = "kaldi_initializer" + + cs = ConfigStore.instance() + cs.store(name=cfg_name, node=KaldiInitializerConfig) + + cli_main() diff --git a/fairseq/examples/speech_recognition/models/__init__.py b/fairseq/examples/speech_recognition/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54b5a1c31243e55d384f80ef9514461cd35b15c6 --- /dev/null +++ b/fairseq/examples/speech_recognition/models/__init__.py @@ -0,0 +1,8 @@ +import importlib +import os + + +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + model_name = file[: file.find(".py")] + importlib.import_module("examples.speech_recognition.models." + model_name) diff --git a/fairseq/examples/speech_recognition/models/vggtransformer.py b/fairseq/examples/speech_recognition/models/vggtransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bca0ae59a8cbe2b7c337e395021c883a61d101ee --- /dev/null +++ b/fairseq/examples/speech_recognition/models/vggtransformer.py @@ -0,0 +1,1020 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
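+
+# This module registers the VGG + Transformer ASR models used in this example:
+#   - "asr_vggtransformer": the encoder-decoder VGGTransformerModel, with the
+#     vggtransformer_1 / vggtransformer_2 / vggtransformer_base presets below
+#   - "asr_vggtransformer_encoder": the encoder-only VGGTransformerEncoderModel
+#     for CTC-style training, with the vggtransformer_enc_1 preset
+# The presets are registered via register_model_architecture at the end of the
+# file and only fill in default hyperparameters.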
+ +import argparse +import math +from collections.abc import Iterable + +import torch +import torch.nn as nn +from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqEncoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import ( + LinearizedConvolution, + TransformerDecoderLayer, + TransformerEncoderLayer, + VGGBlock, +) + + +@register_model("asr_vggtransformer") +class VGGTransformerModel(FairseqEncoderDecoderModel): + """ + Transformers with convolutional context for ASR + https://arxiv.org/abs/1904.11660 + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--vggblock-enc-config", + type=str, + metavar="EXPR", + help=""" + an array of tuples each containing the configuration of one vggblock: + [(out_channels, + conv_kernel_size, + pooling_kernel_size, + num_conv_layers, + use_layer_norm), ...]) + """, + ) + parser.add_argument( + "--transformer-enc-config", + type=str, + metavar="EXPR", + help="""" + a tuple containing the configuration of the encoder transformer layers + configurations: + [(input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout), ...]') + """, + ) + parser.add_argument( + "--enc-output-dim", + type=int, + metavar="N", + help=""" + encoder output dimension, can be None. If specified, projecting the + transformer output to the specified dimension""", + ) + parser.add_argument( + "--in-channels", + type=int, + metavar="N", + help="number of encoder input channels", + ) + parser.add_argument( + "--tgt-embed-dim", + type=int, + metavar="N", + help="embedding dimension of the decoder target tokens", + ) + parser.add_argument( + "--transformer-dec-config", + type=str, + metavar="EXPR", + help=""" + a tuple containing the configuration of the decoder transformer layers + configurations: + [(input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout), ...] 
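+            e.g. (256, 2, 1024, True, 0.2, 0.2, 0.2) describes one layer with
+            input dim 256, 2 attention heads, FFN dim 1024, layer norm applied
+            before the sub-layers, and 0.2 for each of the three dropouts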
+ """, + ) + parser.add_argument( + "--conv-dec-config", + type=str, + metavar="EXPR", + help=""" + an array of tuples for the decoder 1-D convolution config + [(out_channels, conv_kernel_size, use_layer_norm), ...]""", + ) + + @classmethod + def build_encoder(cls, args, task): + return VGGTransformerEncoder( + input_feat_per_channel=args.input_feat_per_channel, + vggblock_config=eval(args.vggblock_enc_config), + transformer_config=eval(args.transformer_enc_config), + encoder_output_dim=args.enc_output_dim, + in_channels=args.in_channels, + ) + + @classmethod + def build_decoder(cls, args, task): + return TransformerDecoder( + dictionary=task.target_dictionary, + embed_dim=args.tgt_embed_dim, + transformer_config=eval(args.transformer_dec_config), + conv_config=eval(args.conv_dec_config), + encoder_output_dim=args.enc_output_dim, + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted + # (in case there are any new ones) + base_architecture(args) + + encoder = cls.build_encoder(args, task) + decoder = cls.build_decoder(args, task) + return cls(encoder, decoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + lprobs.batch_first = True + return lprobs + + +DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2 +DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2 +# 256: embedding dimension +# 4: number of heads +# 1024: FFN +# True: apply layerNorm before (dropout + resiaul) instead of after +# 0.2 (dropout): dropout after MultiheadAttention and second FC +# 0.2 (attention_dropout): dropout in MultiheadAttention +# 0.2 (relu_dropout): dropout after ReLu +DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2 +DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2 + + +# TODO: repace transformer encoder config from one liner +# to explicit args to get rid of this transformation +def prepare_transformer_encoder_params( + input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout, +): + args = argparse.Namespace() + args.encoder_embed_dim = input_dim + args.encoder_attention_heads = num_heads + args.attention_dropout = attention_dropout + args.dropout = dropout + args.activation_dropout = relu_dropout + args.encoder_normalize_before = normalize_before + args.encoder_ffn_embed_dim = ffn_dim + return args + + +def prepare_transformer_decoder_params( + input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout, +): + args = argparse.Namespace() + args.encoder_embed_dim = None + args.decoder_embed_dim = input_dim + args.decoder_attention_heads = num_heads + args.attention_dropout = attention_dropout + args.dropout = dropout + args.activation_dropout = relu_dropout + args.decoder_normalize_before = normalize_before + args.decoder_ffn_embed_dim = ffn_dim + return args + + +class VGGTransformerEncoder(FairseqEncoder): + """VGG + Transformer encoder""" + + def __init__( + self, + input_feat_per_channel, + vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG, + transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, + encoder_output_dim=512, + in_channels=1, + transformer_context=None, + transformer_sampling=None, + ): + """constructor for VGGTransformerEncoder + + Args: + - input_feat_per_channel: feature dim (not including stacked, + just base feature) + - 
in_channel: # input channels (e.g., if stack 8 feature vector + together, this is 8) + - vggblock_config: configuration of vggblock, see comments on + DEFAULT_ENC_VGGBLOCK_CONFIG + - transformer_config: configuration of transformer layer, see comments + on DEFAULT_ENC_TRANSFORMER_CONFIG + - encoder_output_dim: final transformer output embedding dimension + - transformer_context: (left, right) if set, self-attention will be focused + on (t-left, t+right) + - transformer_sampling: an iterable of int, must match with + len(transformer_config), transformer_sampling[i] indicates sampling + factor for i-th transformer layer, after multihead att and feedfoward + part + """ + super().__init__(None) + + self.num_vggblocks = 0 + if vggblock_config is not None: + if not isinstance(vggblock_config, Iterable): + raise ValueError("vggblock_config is not iterable") + self.num_vggblocks = len(vggblock_config) + + self.conv_layers = nn.ModuleList() + self.in_channels = in_channels + self.input_dim = input_feat_per_channel + self.pooling_kernel_sizes = [] + + if vggblock_config is not None: + for _, config in enumerate(vggblock_config): + ( + out_channels, + conv_kernel_size, + pooling_kernel_size, + num_conv_layers, + layer_norm, + ) = config + self.conv_layers.append( + VGGBlock( + in_channels, + out_channels, + conv_kernel_size, + pooling_kernel_size, + num_conv_layers, + input_dim=input_feat_per_channel, + layer_norm=layer_norm, + ) + ) + self.pooling_kernel_sizes.append(pooling_kernel_size) + in_channels = out_channels + input_feat_per_channel = self.conv_layers[-1].output_dim + + transformer_input_dim = self.infer_conv_output_dim( + self.in_channels, self.input_dim + ) + # transformer_input_dim is the output dimension of VGG part + + self.validate_transformer_config(transformer_config) + self.transformer_context = self.parse_transformer_context(transformer_context) + self.transformer_sampling = self.parse_transformer_sampling( + transformer_sampling, len(transformer_config) + ) + + self.transformer_layers = nn.ModuleList() + + if transformer_input_dim != transformer_config[0][0]: + self.transformer_layers.append( + Linear(transformer_input_dim, transformer_config[0][0]) + ) + self.transformer_layers.append( + TransformerEncoderLayer( + prepare_transformer_encoder_params(*transformer_config[0]) + ) + ) + + for i in range(1, len(transformer_config)): + if transformer_config[i - 1][0] != transformer_config[i][0]: + self.transformer_layers.append( + Linear(transformer_config[i - 1][0], transformer_config[i][0]) + ) + self.transformer_layers.append( + TransformerEncoderLayer( + prepare_transformer_encoder_params(*transformer_config[i]) + ) + ) + + self.encoder_output_dim = encoder_output_dim + self.transformer_layers.extend( + [ + Linear(transformer_config[-1][0], encoder_output_dim), + LayerNorm(encoder_output_dim), + ] + ) + + def forward(self, src_tokens, src_lengths, **kwargs): + """ + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + bsz, max_seq_len, _ = src_tokens.size() + x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + x = x.transpose(1, 2).contiguous() + # (B, C, T, feat) + + for layer_idx in range(len(self.conv_layers)): + x = self.conv_layers[layer_idx](x) + + bsz, _, output_seq_len, _ = x.size() + + # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat) + x = x.transpose(1, 2).transpose(0, 1) + x = x.contiguous().view(output_seq_len, bsz, -1) + + input_lengths = src_lengths.clone() 
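+        # each VGGBlock pools along the time axis, so the valid (unpadded)
+        # lengths must be reduced by the same pooling factors before the
+        # encoder padding mask is built below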
+ for s in self.pooling_kernel_sizes: + input_lengths = (input_lengths.float() / s).ceil().long() + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + input_lengths, batch_first=True + ) + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) + attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor) + + transformer_layer_idx = 0 + + for layer_idx in range(len(self.transformer_layers)): + + if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer): + x = self.transformer_layers[layer_idx]( + x, encoder_padding_mask, attn_mask + ) + + if self.transformer_sampling[transformer_layer_idx] != 1: + sampling_factor = self.transformer_sampling[transformer_layer_idx] + x, encoder_padding_mask, attn_mask = self.slice( + x, encoder_padding_mask, attn_mask, sampling_factor + ) + + transformer_layer_idx += 1 + + else: + x = self.transformer_layers[layer_idx](x) + + # encoder_padding_maks is a (T x B) tensor, its [t, b] elements indicate + # whether encoder_output[t, b] is valid or not (valid=0, invalid=1) + + return { + "encoder_out": x, # (T, B, C) + "encoder_padding_mask": encoder_padding_mask.t() + if encoder_padding_mask is not None + else None, + # (B, T) --> (T, B) + } + + def infer_conv_output_dim(self, in_channels, input_dim): + sample_seq_len = 200 + sample_bsz = 10 + x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) + for i, _ in enumerate(self.conv_layers): + x = self.conv_layers[i](x) + x = x.transpose(1, 2) + mb, seq = x.size()[:2] + return x.contiguous().view(mb, seq, -1).size(-1) + + def validate_transformer_config(self, transformer_config): + for config in transformer_config: + input_dim, num_heads = config[:2] + if input_dim % num_heads != 0: + msg = ( + "ERROR in transformer config {}: ".format(config) + + "input dimension {} ".format(input_dim) + + "not dividable by number of heads {}".format(num_heads) + ) + raise ValueError(msg) + + def parse_transformer_context(self, transformer_context): + """ + transformer_context can be the following: + - None; indicates no context is used, i.e., + transformer can access full context + - a tuple/list of two int; indicates left and right context, + any number <0 indicates infinite context + * e.g., (5, 6) indicates that for query at x_t, transformer can + access [t-5, t+6] (inclusive) + * e.g., (-1, 6) indicates that for query at x_t, transformer can + access [0, t+6] (inclusive) + """ + if transformer_context is None: + return None + + if not isinstance(transformer_context, Iterable): + raise ValueError("transformer context must be Iterable if it is not None") + + if len(transformer_context) != 2: + raise ValueError("transformer context must have length 2") + + left_context = transformer_context[0] + if left_context < 0: + left_context = None + + right_context = transformer_context[1] + if right_context < 0: + right_context = None + + if left_context is None and right_context is None: + return None + + return (left_context, right_context) + + def parse_transformer_sampling(self, transformer_sampling, num_layers): + """ + parsing transformer sampling configuration + + Args: + - transformer_sampling, accepted input: + * None, indicating no sampling + * an Iterable with int (>0) as element + - num_layers, expected number of transformer layers, must match with + the length of transformer_sampling if it is not None + + Returns: + - A tuple with length num_layers + """ + if transformer_sampling is None: + 
return (1,) * num_layers + + if not isinstance(transformer_sampling, Iterable): + raise ValueError( + "transformer_sampling must be an iterable if it is not None" + ) + + if len(transformer_sampling) != num_layers: + raise ValueError( + "transformer_sampling {} does not match with the number " + "of layers {}".format(transformer_sampling, num_layers) + ) + + for layer, value in enumerate(transformer_sampling): + if not isinstance(value, int): + raise ValueError("Invalid value in transformer_sampling: ") + if value < 1: + raise ValueError( + "{} layer's subsampling is {}.".format(layer, value) + + " This is not allowed! " + ) + return transformer_sampling + + def slice(self, embedding, padding_mask, attn_mask, sampling_factor): + """ + embedding is a (T, B, D) tensor + padding_mask is a (B, T) tensor or None + attn_mask is a (T, T) tensor or None + """ + embedding = embedding[::sampling_factor, :, :] + if padding_mask is not None: + padding_mask = padding_mask[:, ::sampling_factor] + if attn_mask is not None: + attn_mask = attn_mask[::sampling_factor, ::sampling_factor] + + return embedding, padding_mask, attn_mask + + def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1): + """ + create attention mask according to sequence lengths and transformer + context + + Args: + - input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is + the length of b-th sequence + - subsampling_factor: int + * Note that the left_context and right_context is specified in + the input frame-level while input to transformer may already + go through subsampling (e.g., the use of striding in vggblock) + we use subsampling_factor to scale the left/right context + + Return: + - a (T, T) binary tensor or None, where T is max(input_lengths) + * if self.transformer_context is None, None + * if left_context is None, + * attn_mask[t, t + right_context + 1:] = 1 + * others = 0 + * if right_context is None, + * attn_mask[t, 0:t - left_context] = 1 + * others = 0 + * elsif + * attn_mask[t, t - left_context: t + right_context + 1] = 0 + * others = 1 + """ + if self.transformer_context is None: + return None + + maxT = torch.max(input_lengths).item() + attn_mask = torch.zeros(maxT, maxT) + + left_context = self.transformer_context[0] + right_context = self.transformer_context[1] + if left_context is not None: + left_context = math.ceil(self.transformer_context[0] / subsampling_factor) + if right_context is not None: + right_context = math.ceil(self.transformer_context[1] / subsampling_factor) + + for t in range(maxT): + if left_context is not None: + st = 0 + en = max(st, t - left_context) + attn_mask[t, st:en] = 1 + if right_context is not None: + st = t + right_context + 1 + st = min(st, maxT - 1) + attn_mask[t, st:] = 1 + + return attn_mask.to(input_lengths.device) + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. 
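+
+    Target tokens are embedded, passed through the 1-D convolutional layers
+    described by *conv_config*, then through the transformer layers from
+    *transformer_config*, and finally projected to the vocabulary via ``fc_out``.
+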
+ Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs. + Default: ``False`` + left_pad (bool, optional): whether the input is left-padded. Default: + ``False`` + """ + + def __init__( + self, + dictionary, + embed_dim=512, + transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, + conv_config=DEFAULT_DEC_CONV_CONFIG, + encoder_output_dim=512, + ): + + super().__init__(dictionary) + vocab_size = len(dictionary) + self.padding_idx = dictionary.pad() + self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx) + + self.conv_layers = nn.ModuleList() + for i in range(len(conv_config)): + out_channels, kernel_size, layer_norm = conv_config[i] + if i == 0: + conv_layer = LinearizedConv1d( + embed_dim, out_channels, kernel_size, padding=kernel_size - 1 + ) + else: + conv_layer = LinearizedConv1d( + conv_config[i - 1][0], + out_channels, + kernel_size, + padding=kernel_size - 1, + ) + self.conv_layers.append(conv_layer) + if layer_norm: + self.conv_layers.append(nn.LayerNorm(out_channels)) + self.conv_layers.append(nn.ReLU()) + + self.layers = nn.ModuleList() + if conv_config[-1][0] != transformer_config[0][0]: + self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0])) + self.layers.append( + TransformerDecoderLayer( + prepare_transformer_decoder_params(*transformer_config[0]) + ) + ) + + for i in range(1, len(transformer_config)): + if transformer_config[i - 1][0] != transformer_config[i][0]: + self.layers.append( + Linear(transformer_config[i - 1][0], transformer_config[i][0]) + ) + self.layers.append( + TransformerDecoderLayer( + prepare_transformer_decoder_params(*transformer_config[i]) + ) + ) + self.fc_out = Linear(transformer_config[-1][0], vocab_size) + + def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for input feeding/teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + Returns: + tuple: + - the last decoder layer's output of shape `(batch, tgt_len, + vocab)` + - the last decoder layer's attention weights of shape `(batch, + tgt_len, src_len)` + """ + target_padding_mask = ( + (prev_output_tokens == self.padding_idx).to(prev_output_tokens.device) + if incremental_state is None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + + # embed tokens + x = self.embed_tokens(prev_output_tokens) + + # B x T x C -> T x B x C + x = self._transpose_if_training(x, incremental_state) + + for layer in self.conv_layers: + if isinstance(layer, LinearizedConvolution): + x = layer(x, incremental_state) + else: + x = layer(x) + + # B x T x C -> T x B x C + x = self._transpose_if_inference(x, incremental_state) + + # decoder layers + for layer in self.layers: + if isinstance(layer, TransformerDecoderLayer): + x, *_ = layer( + x, + (encoder_out["encoder_out"] if encoder_out is not None else None), + ( + encoder_out["encoder_padding_mask"].t() + if encoder_out["encoder_padding_mask"] is not None + else None + ), + incremental_state, + self_attn_mask=( + self.buffered_future_mask(x) + if incremental_state is None + else None + ), + 
self_attn_padding_mask=( + target_padding_mask if incremental_state is None else None + ), + ) + else: + x = layer(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + x = self.fc_out(x) + + return x, None + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu( + utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def _transpose_if_training(self, x, incremental_state): + if incremental_state is None: + x = x.transpose(0, 1) + return x + + def _transpose_if_inference(self, x, incremental_state): + if incremental_state: + x = x.transpose(0, 1) + return x + + +@register_model("asr_vggtransformer_encoder") +class VGGTransformerEncoderModel(FairseqEncoderModel): + def __init__(self, encoder): + super().__init__(encoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--vggblock-enc-config", + type=str, + metavar="EXPR", + help=""" + an array of tuples each containing the configuration of one vggblock + [(out_channels, conv_kernel_size, pooling_kernel_size,num_conv_layers), ...] + """, + ) + parser.add_argument( + "--transformer-enc-config", + type=str, + metavar="EXPR", + help=""" + a tuple containing the configuration of the Transformer layers + configurations: + [(input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout), ]""", + ) + parser.add_argument( + "--enc-output-dim", + type=int, + metavar="N", + help="encoder output dimension, projecting the LSTM output", + ) + parser.add_argument( + "--in-channels", + type=int, + metavar="N", + help="number of encoder input channels", + ) + parser.add_argument( + "--transformer-context", + type=str, + metavar="EXPR", + help=""" + either None or a tuple of two ints, indicating left/right context a + transformer can have access to""", + ) + parser.add_argument( + "--transformer-sampling", + type=str, + metavar="EXPR", + help=""" + either None or a tuple of ints, indicating sampling factor in each layer""", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + base_architecture_enconly(args) + encoder = VGGTransformerEncoderOnly( + vocab_size=len(task.target_dictionary), + input_feat_per_channel=args.input_feat_per_channel, + vggblock_config=eval(args.vggblock_enc_config), + transformer_config=eval(args.transformer_enc_config), + encoder_output_dim=args.enc_output_dim, + in_channels=args.in_channels, + transformer_context=eval(args.transformer_context), + transformer_sampling=eval(args.transformer_sampling), + ) + return cls(encoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + # net_output['encoder_out'] is a (T, B, D) tensor + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + # lprobs is a (T, B, D) tensor + # we need to transoose to get (B, T, D) tensor + lprobs = lprobs.transpose(0, 1).contiguous() + lprobs.batch_first = True + return lprobs + + +class VGGTransformerEncoderOnly(VGGTransformerEncoder): + def __init__( + self, + vocab_size, + 
input_feat_per_channel, + vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG, + transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, + encoder_output_dim=512, + in_channels=1, + transformer_context=None, + transformer_sampling=None, + ): + super().__init__( + input_feat_per_channel=input_feat_per_channel, + vggblock_config=vggblock_config, + transformer_config=transformer_config, + encoder_output_dim=encoder_output_dim, + in_channels=in_channels, + transformer_context=transformer_context, + transformer_sampling=transformer_sampling, + ) + self.fc_out = Linear(self.encoder_output_dim, vocab_size) + + def forward(self, src_tokens, src_lengths, **kwargs): + """ + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + + enc_out = super().forward(src_tokens, src_lengths) + x = self.fc_out(enc_out["encoder_out"]) + # x = F.log_softmax(x, dim=-1) + # Note: no need this line, because model.get_normalized_prob will call + # log_softmax + return { + "encoder_out": x, # (T, B, C) + "encoder_padding_mask": enc_out["encoder_padding_mask"], # (T, B) + } + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return (1e6, 1e6) # an arbitrary large number + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + # nn.init.uniform_(m.weight, -0.1, 0.1) + # nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True, dropout=0): + """Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + # m.weight.data.uniform_(-0.1, 0.1) + # if bias: + # m.bias.data.uniform_(-0.1, 0.1) + return m + + +def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): + """Weight-normalized Conv1d layer optimized for decoding""" + m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) + std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) + nn.init.normal_(m.weight, mean=0, std=std) + nn.init.constant_(m.bias, 0) + return nn.utils.weight_norm(m, dim=2) + + +def LayerNorm(embedding_dim): + m = nn.LayerNorm(embedding_dim) + return m + + +# seq2seq models +def base_architecture(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG + ) + args.transformer_enc_config = getattr( + args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 512) + args.in_channels = getattr(args, "in_channels", 1) + args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128) + args.transformer_dec_config = getattr( + args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG + ) + args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG) + args.transformer_context = getattr(args, "transformer_context", "None") + + +@register_model_architecture("asr_vggtransformer", "vggtransformer_1") +def vggtransformer_1(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" + ) + args.transformer_enc_config = getattr( + args, + "transformer_enc_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14", + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 1024) + args.tgt_embed_dim = getattr(args, 
"tgt_embed_dim", 128) + args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") + args.transformer_dec_config = getattr( + args, + "transformer_dec_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4", + ) + + +@register_model_architecture("asr_vggtransformer", "vggtransformer_2") +def vggtransformer_2(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" + ) + args.transformer_enc_config = getattr( + args, + "transformer_enc_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16", + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 1024) + args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512) + args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") + args.transformer_dec_config = getattr( + args, + "transformer_dec_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6", + ) + + +@register_model_architecture("asr_vggtransformer", "vggtransformer_base") +def vggtransformer_base(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" + ) + args.transformer_enc_config = getattr( + args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12" + ) + + args.enc_output_dim = getattr(args, "enc_output_dim", 512) + args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512) + args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") + args.transformer_dec_config = getattr( + args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6" + ) + # Size estimations: + # Encoder: + # - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3 = 258K + # Transformer: + # - input dimension adapter: 2560 x 512 -> 1.31M + # - transformer_layers (x12) --> 37.74M + # * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M + # * FFN weight: 512*2048*2 = 2.097M + # - output dimension adapter: 512 x 512 -> 0.26 M + # Decoder: + # - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3 + # - transformer_layer: (x6) --> 25.16M + # * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M + # * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M + # * FFN: 512*2048*2 = 2.097M + # Final FC: + # - FC: 512*5000 = 256K (assuming vocab size 5K) + # In total: + # ~65 M + + +# CTC models +def base_architecture_enconly(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2" + ) + args.transformer_enc_config = getattr( + args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2" + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 512) + args.in_channels = getattr(args, "in_channels", 1) + args.transformer_context = getattr(args, "transformer_context", "None") + args.transformer_sampling = getattr(args, "transformer_sampling", "None") + + +@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1") +def vggtransformer_enc_1(args): + # vggtransformer_1 is the same as vggtransformer_enc_big, except the number + # of layers is increased to 16 + # keep it here for backward compatiablity purpose + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.vggblock_enc_config = 
getattr( + args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" + ) + args.transformer_enc_config = getattr( + args, + "transformer_enc_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16", + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 1024) diff --git a/fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py b/fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..655a9b0d19d11e35511392a016f9d6b7d7aa2925 --- /dev/null +++ b/fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.modules.fairseq_dropout import FairseqDropout + + +default_conv_enc_config = """[ + (400, 13, 170, 0.2), + (440, 14, 0, 0.214), + (484, 15, 0, 0.22898), + (532, 16, 0, 0.2450086), + (584, 17, 0, 0.262159202), + (642, 18, 0, 0.28051034614), + (706, 19, 0, 0.30014607037), + (776, 20, 0, 0.321156295296), + (852, 21, 0, 0.343637235966), + (936, 22, 0, 0.367691842484), + (1028, 23, 0, 0.393430271458), + (1130, 24, 0, 0.42097039046), + (1242, 25, 0, 0.450438317792), + (1366, 26, 0, 0.481969000038), + (1502, 27, 0, 0.51570683004), + (1652, 28, 0, 0.551806308143), + (1816, 29, 0, 0.590432749713), +]""" + + +@register_model("asr_w2l_conv_glu_encoder") +class W2lConvGluEncoderModel(FairseqEncoderModel): + def __init__(self, encoder): + super().__init__(encoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--in-channels", + type=int, + metavar="N", + help="number of encoder input channels", + ) + parser.add_argument( + "--conv-enc-config", + type=str, + metavar="EXPR", + help=""" + an array of tuples each containing the configuration of one conv layer + [(out_channels, kernel_size, padding, dropout), ...] 
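+            e.g. the first entry of the default config, (400, 13, 170, 0.2), is
+            a conv layer with 400 output channels (halved to 200 by the GLU),
+            kernel size 13, padding 170 and dropout 0.2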
+ """, + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config) + encoder = W2lConvGluEncoder( + vocab_size=len(task.target_dictionary), + input_feat_per_channel=args.input_feat_per_channel, + in_channels=args.in_channels, + conv_enc_config=eval(conv_enc_config), + ) + return cls(encoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + lprobs.batch_first = False + return lprobs + + +class W2lConvGluEncoder(FairseqEncoder): + def __init__( + self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config + ): + super().__init__(None) + + self.input_dim = input_feat_per_channel + if in_channels != 1: + raise ValueError("only 1 input channel is currently supported") + + self.conv_layers = nn.ModuleList() + self.linear_layers = nn.ModuleList() + self.dropouts = [] + cur_channels = input_feat_per_channel + + for out_channels, kernel_size, padding, dropout in conv_enc_config: + layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding) + layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init + self.conv_layers.append(nn.utils.weight_norm(layer)) + self.dropouts.append( + FairseqDropout(dropout, module_name=self.__class__.__name__) + ) + if out_channels % 2 != 0: + raise ValueError("odd # of out_channels is incompatible with GLU") + cur_channels = out_channels // 2 # halved by GLU + + for out_channels in [2 * cur_channels, vocab_size]: + layer = nn.Linear(cur_channels, out_channels) + layer.weight.data.mul_(math.sqrt(3)) + self.linear_layers.append(nn.utils.weight_norm(layer)) + cur_channels = out_channels // 2 + + def forward(self, src_tokens, src_lengths, **kwargs): + + """ + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + B, T, _ = src_tokens.size() + x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1 + + for layer_idx in range(len(self.conv_layers)): + x = self.conv_layers[layer_idx](x) + x = F.glu(x, dim=1) + x = self.dropouts[layer_idx](x) + + x = x.transpose(1, 2).contiguous() # (B, T, 908) + x = self.linear_layers[0](x) + x = F.glu(x, dim=2) + x = self.dropouts[-1](x) + x = self.linear_layers[1](x) + + assert x.size(0) == B + assert x.size(1) == T + + encoder_out = x.transpose(0, 1) # (T, B, vocab_size) + + # need to debug this -- find a simpler/elegant way in pytorch APIs + encoder_padding_mask = ( + torch.arange(T).view(1, T).expand(B, -1).to(x.device) + >= src_lengths.view(B, 1).expand(-1, T) + ).t() # (B x T) -> (T x B) + + return { + "encoder_out": encoder_out, # (T, B, vocab_size) + "encoder_padding_mask": encoder_padding_mask, # (T, B) + } + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return (1e6, 1e6) # an arbitrary large number + + +@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc") +def w2l_conv_glu_enc(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.in_channels = getattr(args, "in_channels", 1) + args.conv_enc_config = getattr(args, "conv_enc_config", 
default_conv_enc_config) diff --git a/fairseq/examples/speech_recognition/new/README.md b/fairseq/examples/speech_recognition/new/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5fa0e97245d3ba6db69d11222261b0644960183d --- /dev/null +++ b/fairseq/examples/speech_recognition/new/README.md @@ -0,0 +1,43 @@ +# Flashlight Decoder + +This script runs decoding for pre-trained speech recognition models. + +## Usage + +Assuming a few variables: + +```bash +checkpoint= +data= +lm_model= +lexicon= +``` + +Example usage for decoding a fine-tuned Wav2Vec model: + +```bash +python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \ + task=audio_pretraining \ + task.data=$data \ + task.labels=ltr \ + common_eval.path=$checkpoint \ + decoding.type=kenlm \ + decoding.lexicon=$lexicon \ + decoding.lmpath=$lm_model \ + dataset.gen_subset=dev_clean,dev_other,test_clean,test_other +``` + +Example usage for using Ax to sweep WER parameters (requires `pip install hydra-ax-sweeper`): + +```bash +python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \ + hydra/sweeper=ax \ + task=audio_pretraining \ + task.data=$data \ + task.labels=ltr \ + common_eval.path=$checkpoint \ + decoding.type=kenlm \ + decoding.lexicon=$lexicon \ + decoding.lmpath=$lm_model \ + dataset.gen_subset=dev_other +``` diff --git a/fairseq/examples/speech_recognition/new/__init__.py b/fairseq/examples/speech_recognition/new/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml b/fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38e9c221db234877056f59e5d2881199a3d601b4 --- /dev/null +++ b/fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml @@ -0,0 +1,29 @@ +# @package hydra.sweeper +_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper +max_batch_size: null +ax_config: + max_trials: 128 + early_stop: + minimize: true + max_epochs_without_improvement: 10 + epsilon: 0.025 + experiment: + name: ${dataset.gen_subset} + objective_name: wer + minimize: true + parameter_constraints: null + outcome_constraints: null + status_quo: null + client: + verbose_logging: false + random_seed: null + params: + decoding.lmweight: + type: range + bounds: [0.0, 5.0] + decoding.wordscore: + type: range + bounds: [-5.0, 5.0] + decoding.silweight: + type: range + bounds: [ -8.0, 0.0 ] diff --git a/fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml b/fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eaaebcf5f606947d6dc520f611e749128c4a009c --- /dev/null +++ b/fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml @@ -0,0 +1,29 @@ +# @package hydra.sweeper +_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper +max_batch_size: null +ax_config: + max_trials: 64 + early_stop: + minimize: true + max_epochs_without_improvement: 10 + epsilon: 0.025 + experiment: + name: ${dataset.gen_subset} + objective_name: wer + minimize: true + parameter_constraints: null + outcome_constraints: null + status_quo: null + client: + verbose_logging: false + random_seed: null + params: + decoding.lmweight: + type: range + bounds: [0.0, 10.0] + decoding.wordscore: + type: range + bounds: [-10.0, 10.0] + decoding.silweight: + type: range + 
bounds: [ -10.0, 0.0 ] diff --git a/fairseq/examples/speech_recognition/new/conf/infer.yaml b/fairseq/examples/speech_recognition/new/conf/infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d168d06afbd95c1322c70d065f453a8e967477b --- /dev/null +++ b/fairseq/examples/speech_recognition/new/conf/infer.yaml @@ -0,0 +1,27 @@ +# @package _group_ + +defaults: + - task: null + - model: null + +hydra: + run: + dir: ${common_eval.results_path}/${dataset.gen_subset} + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path} + subdir: ${dataset.gen_subset} +common: + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +common_eval: + results_path: null + path: null + post_process: letter + quiet: true +dataset: + max_tokens: 3000000 + gen_subset: test +distributed_training: + distributed_world_size: 1 +decoding: + beam: 5 + type: viterbi diff --git a/fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml b/fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0a9b0e5863e0ceae3f5b75049c5357fe101cbcc --- /dev/null +++ b/fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml @@ -0,0 +1,28 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - common_eval.path + sweep: + dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} +# subdir: ${hydra.job.override_dirname} + launcher: + cpus_per_task: 16 + gpus_per_node: 1 + tasks_per_node: 1 + nodes: 1 + partition: devlab,learnlab + mem_gb: 100 + timeout_min: 2000 + max_num_timeout: 10 + name: ${env:PREFIX}_${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/%j + constraint: volta32gb + exclude: learnfair7598 \ No newline at end of file diff --git a/fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml b/fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0c442f76d06516de3f85675029a8d79e5c4ee69 --- /dev/null +++ b/fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - common_eval.path + sweep: + dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} +# subdir: ${hydra.job.override_dirname} + launcher: + cpus_per_task: 16 + gpus_per_node: 2 + tasks_per_node: 2 + nodes: 1 + partition: devlab,learnlab + mem_gb: 100 + timeout_min: 2000 + max_num_timeout: 10 + name: ${env:PREFIX}_${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/%j + constraint: volta32gb \ No newline at end of file diff --git a/fairseq/examples/speech_recognition/new/decoders/__init__.py b/fairseq/examples/speech_recognition/new/decoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/speech_recognition/new/decoders/base_decoder.py b/fairseq/examples/speech_recognition/new/decoders/base_decoder.py new file mode 100644 index 
0000000000000000000000000000000000000000..a097969b3c0650cf8ea2ab5f8e96bbc68ea9b97f --- /dev/null +++ b/fairseq/examples/speech_recognition/new/decoders/base_decoder.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools as it +from typing import Any, Dict, List + +import torch +from fairseq.data.dictionary import Dictionary +from fairseq.models.fairseq_model import FairseqModel + + +class BaseDecoder: + def __init__(self, tgt_dict: Dictionary) -> None: + self.tgt_dict = tgt_dict + self.vocab_size = len(tgt_dict) + + self.blank = ( + tgt_dict.index("") + if "" in tgt_dict.indices + else tgt_dict.bos() + ) + if "" in tgt_dict.indices: + self.silence = tgt_dict.index("") + elif "|" in tgt_dict.indices: + self.silence = tgt_dict.index("|") + else: + self.silence = tgt_dict.eos() + + def generate( + self, models: List[FairseqModel], sample: Dict[str, Any], **unused + ) -> List[List[Dict[str, torch.LongTensor]]]: + encoder_input = { + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" + } + emissions = self.get_emissions(models, encoder_input) + return self.decode(emissions) + + def get_emissions( + self, + models: List[FairseqModel], + encoder_input: Dict[str, Any], + ) -> torch.FloatTensor: + model = models[0] + encoder_out = model(**encoder_input) + if hasattr(model, "get_logits"): + emissions = model.get_logits(encoder_out) + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) + return emissions.transpose(0, 1).float().cpu().contiguous() + + def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: + idxs = (g[0] for g in it.groupby(idxs)) + idxs = filter(lambda x: x != self.blank, idxs) + return torch.LongTensor(list(idxs)) + + def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + raise NotImplementedError diff --git a/fairseq/examples/speech_recognition/new/decoders/decoder.py b/fairseq/examples/speech_recognition/new/decoders/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b5bec8cf707b53104ef7a45993a5db2893d3443b --- /dev/null +++ b/fairseq/examples/speech_recognition/new/decoders/decoder.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
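
# Note: BaseDecoder.get_tokens() above implements the standard CTC collapse --
# itertools.groupby first merges consecutive repeats, then the blank index is
# filtered out. Illustrative sketch with made-up token ids, assuming blank = 0:
#
#     frame argmax       : [0, 7, 7, 0, 0, 3, 3, 3, 0, 7]
#     after groupby      : [0, 7, 0, 3, 0, 7]
#     after blank filter : [7, 3, 7]
#
# so get_tokens() would return torch.LongTensor([7, 3, 7]) for that input.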
+ +from typing import Union + +from fairseq.data.dictionary import Dictionary + +from .decoder_config import DecoderConfig, FlashlightDecoderConfig +from .base_decoder import BaseDecoder + + +def Decoder( + cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary +) -> BaseDecoder: + + if cfg.type == "viterbi": + from .viterbi_decoder import ViterbiDecoder + + return ViterbiDecoder(tgt_dict) + if cfg.type == "kenlm": + from .flashlight_decoder import KenLMDecoder + + return KenLMDecoder(cfg, tgt_dict) + if cfg.type == "fairseqlm": + from .flashlight_decoder import FairseqLMDecoder + + return FairseqLMDecoder(cfg, tgt_dict) + raise NotImplementedError(f"Invalid decoder name: {cfg.name}") diff --git a/fairseq/examples/speech_recognition/new/decoders/decoder_config.py b/fairseq/examples/speech_recognition/new/decoders/decoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..659eb94a9b8187a7c126d7b439ac2742f9d72022 --- /dev/null +++ b/fairseq/examples/speech_recognition/new/decoders/decoder_config.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import Optional + +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.dataclass.constants import ChoiceEnum +from omegaconf import MISSING + + +DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"]) + + +@dataclass +class DecoderConfig(FairseqDataclass): + type: DECODER_CHOICES = field( + default="viterbi", + metadata={"help": "The type of decoder to use"}, + ) + + +@dataclass +class FlashlightDecoderConfig(FairseqDataclass): + nbest: int = field( + default=1, + metadata={"help": "Number of decodings to return"}, + ) + unitlm: bool = field( + default=False, + metadata={"help": "If set, use unit language model"}, + ) + lmpath: str = field( + default=MISSING, + metadata={"help": "Language model for KenLM decoder"}, + ) + lexicon: Optional[str] = field( + default=None, + metadata={"help": "Lexicon for Flashlight decoder"}, + ) + beam: int = field( + default=50, + metadata={"help": "Number of beams to use for decoding"}, + ) + beamthreshold: float = field( + default=50.0, + metadata={"help": "Threshold for beam search decoding"}, + ) + beamsizetoken: Optional[int] = field( + default=None, metadata={"help": "Beam size to use"} + ) + wordscore: float = field( + default=-1, + metadata={"help": "Word score for KenLM decoder"}, + ) + unkweight: float = field( + default=-math.inf, + metadata={"help": "Unknown weight for KenLM decoder"}, + ) + silweight: float = field( + default=0, + metadata={"help": "Silence weight for KenLM decoder"}, + ) + lmweight: float = field( + default=2, + metadata={"help": "Weight for LM while interpolating score"}, + ) diff --git a/fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py b/fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7790fcdb8202101d27bf6798554ed7bc25f5c1a6 --- /dev/null +++ b/fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
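
# Note: the Decoder() factory above dispatches purely on cfg.type, so the
# config object must expose both `type` and, for LM-based decoding, the
# FlashlightDecoderConfig fields; in practice that is the combined
# DecodingConfig defined in infer.py further below. A hedged usage sketch
# (task/models/sample are assumed to be already loaded, nothing here is a
# documented API guarantee):
#
#     from examples.speech_recognition.new.decoders.decoder import Decoder
#     from examples.speech_recognition.new.decoders.decoder_config import DecoderConfig
#
#     viterbi = Decoder(DecoderConfig(type="viterbi"), task.target_dictionary)
#     hypos = viterbi.generate(models, sample)  # List[List[{"tokens", "score"}]]
#
# Also note the final error message in Decoder() formats cfg.name, while the
# config dataclasses above only define cfg.type; cfg.type appears to be what
# was intended.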
+ +import gc +import os.path as osp +import warnings +from collections import deque, namedtuple +from typing import Any, Dict, Tuple + +import numpy as np +import torch +from fairseq import tasks +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models.fairseq_model import FairseqModel +from fairseq.utils import apply_to_sample +from omegaconf import open_dict, OmegaConf + +from typing import List + +from .decoder_config import FlashlightDecoderConfig +from .base_decoder import BaseDecoder + +try: + from flashlight.lib.text.decoder import ( + LM, + CriterionType, + DecodeResult, + KenLM, + LexiconDecoder, + LexiconDecoderOptions, + LexiconFreeDecoder, + LexiconFreeDecoderOptions, + LMState, + SmearingMode, + Trie, + ) + from flashlight.lib.text.dictionary import create_word_dict, load_words + from flashlight.lib.text.dictionary import Dictionary as flDictionary +except ImportError: + warnings.warn( + "flashlight python bindings are required to use this functionality. " + "Please install from " + "https://github.com/facebookresearch/flashlight/tree/master/bindings/python" + ) + LM = object + LMState = object + + +class KenLMDecoder(BaseDecoder): + def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None: + super().__init__(tgt_dict) + + self.nbest = cfg.nbest + self.unitlm = cfg.unitlm + + if cfg.lexicon: + self.lexicon = load_words(cfg.lexicon) + self.word_dict = create_word_dict(self.lexicon) + self.unk_word = self.word_dict.get_index("") + + self.lm = KenLM(cfg.lmpath, self.word_dict) + self.trie = Trie(self.vocab_size, self.silence) + + start_state = self.lm.start(False) + for word, spellings in self.lexicon.items(): + word_idx = self.word_dict.get_index(word) + _, score = self.lm.score(start_state, word_idx) + for spelling in spellings: + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{word} {spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder_opts = LexiconDecoderOptions( + beam_size=cfg.beam, + beam_size_token=cfg.beamsizetoken or len(tgt_dict), + beam_threshold=cfg.beamthreshold, + lm_weight=cfg.lmweight, + word_score=cfg.wordscore, + unk_score=cfg.unkweight, + sil_score=cfg.silweight, + log_add=False, + criterion_type=CriterionType.CTC, + ) + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + [], + self.unitlm, + ) + else: + assert self.unitlm, "Lexicon-free decoding requires unit LM" + + self.word_dict = flDictionary() + for sym in tgt_dict.symbols: + self.word_dict.add_entry(sym, tgt_dict.index(sym)) + self.lm = KenLM(cfg.lmpath, self.word_dict) + self.decoder_opts = LexiconFreeDecoderOptions( + beam_size=cfg.beam, + beam_size_token=cfg.beamsizetoken or len(tgt_dict), + beam_threshold=cfg.beamthreshold, + lm_weight=cfg.lmweight, + sil_score=cfg.silweight, + log_add=False, + criterion_type=CriterionType.CTC, + ) + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) + + def get_timesteps(self, token_idxs: List[int]) -> List[int]: + """Returns frame numbers corresponding to every non-blank token. + + Parameters + ---------- + token_idxs : List[int] + IDs of decoded tokens. + + Returns + ------- + List[int] + Frame numbers corresponding to every non-blank token. 
+ """ + timesteps = [] + for i, token_idx in enumerate(token_idxs): + if token_idx == self.blank: + continue + if i == 0 or token_idx != token_idxs[i-1]: + timesteps.append(i) + return timesteps + + def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + B, T, N = emissions.size() + hypos = [] + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] + hypos.append( + [ + { + "tokens": self.get_tokens(result.tokens), + "score": result.score, + "timesteps": self.get_timesteps(result.tokens), + "words": [ + self.word_dict.get_entry(x) for x in result.words if x >= 0 + ], + } + for result in nbest_results + ] + ) + return hypos + + +FairseqLMState = namedtuple( + "FairseqLMState", + [ + "prefix", + "incremental_state", + "probs", + ], +) + + +class FairseqLM(LM): + def __init__(self, dictionary: Dictionary, model: FairseqModel) -> None: + super().__init__() + + self.dictionary = dictionary + self.model = model + self.unk = self.dictionary.unk() + + self.save_incremental = False # this currently does not work properly + self.max_cache = 20_000 + + if torch.cuda.is_available(): + model.cuda() + model.eval() + model.make_generation_fast_() + + self.states = {} + self.stateq = deque() + + def start(self, start_with_nothing: bool) -> LMState: + state = LMState() + prefix = torch.LongTensor([[self.dictionary.eos()]]) + incremental_state = {} if self.save_incremental else None + with torch.no_grad(): + res = self.model(prefix.cuda(), incremental_state=incremental_state) + probs = self.model.get_normalized_probs(res, log_probs=True, sample=None) + + if incremental_state is not None: + incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state) + self.states[state] = FairseqLMState( + prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy() + ) + self.stateq.append(state) + + return state + + def score( + self, + state: LMState, + token_index: int, + no_cache: bool = False, + ) -> Tuple[LMState, int]: + """ + Evaluate language model based on the current lm state and new word + Parameters: + ----------- + state: current lm state + token_index: index of the word + (can be lexicon index then you should store inside LM the + mapping between indices of lexicon and lm, or lm index of a word) + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + curr_state = self.states[state] + + def trim_cache(targ_size: int) -> None: + while len(self.stateq) > targ_size: + rem_k = self.stateq.popleft() + rem_st = self.states[rem_k] + rem_st = FairseqLMState(rem_st.prefix, None, None) + self.states[rem_k] = rem_st + + if curr_state.probs is None: + new_incremental_state = ( + curr_state.incremental_state.copy() + if curr_state.incremental_state is not None + else None + ) + with torch.no_grad(): + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cuda(), new_incremental_state + ) + elif self.save_incremental: + new_incremental_state = {} + + res = self.model( + torch.from_numpy(curr_state.prefix).cuda(), + incremental_state=new_incremental_state, + ) + probs = self.model.get_normalized_probs( + res, log_probs=True, sample=None + ) + + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cpu(), new_incremental_state + ) + + curr_state = FairseqLMState( + curr_state.prefix, new_incremental_state, probs[0, 
-1].cpu().numpy() + ) + + if not no_cache: + self.states[state] = curr_state + self.stateq.append(state) + + score = curr_state.probs[token_index].item() + + trim_cache(self.max_cache) + + outstate = state.child(token_index) + if outstate not in self.states and not no_cache: + prefix = np.concatenate( + [curr_state.prefix, torch.LongTensor([[token_index]])], -1 + ) + incr_state = curr_state.incremental_state + + self.states[outstate] = FairseqLMState(prefix, incr_state, None) + + if token_index == self.unk: + score = float("-inf") + + return outstate, score + + def finish(self, state: LMState) -> Tuple[LMState, int]: + """ + Evaluate eos for language model based on the current lm state + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + return self.score(state, self.dictionary.eos()) + + def empty_cache(self) -> None: + self.states = {} + self.stateq = deque() + gc.collect() + + +class FairseqLMDecoder(BaseDecoder): + def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None: + super().__init__(tgt_dict) + + self.nbest = cfg.nbest + self.unitlm = cfg.unitlm + + self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None + self.idx_to_wrd = {} + + checkpoint = torch.load(cfg.lmpath, map_location="cpu") + + if "cfg" in checkpoint and checkpoint["cfg"] is not None: + lm_args = checkpoint["cfg"] + else: + lm_args = convert_namespace_to_omegaconf(checkpoint["args"]) + + if not OmegaConf.is_dict(lm_args): + lm_args = OmegaConf.create(lm_args) + + with open_dict(lm_args.task): + lm_args.task.data = osp.dirname(cfg.lmpath) + + task = tasks.setup_task(lm_args.task) + model = task.build_model(lm_args.model) + model.load_state_dict(checkpoint["model"], strict=False) + + self.trie = Trie(self.vocab_size, self.silence) + + self.word_dict = task.dictionary + self.unk_word = self.word_dict.unk() + self.lm = FairseqLM(self.word_dict, model) + + if self.lexicon: + start_state = self.lm.start(False) + for i, (word, spellings) in enumerate(self.lexicon.items()): + if self.unitlm: + word_idx = i + self.idx_to_wrd[i] = word + score = 0 + else: + word_idx = self.word_dict.index(word) + _, score = self.lm.score(start_state, word_idx, no_cache=True) + + for spelling in spellings: + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder_opts = LexiconDecoderOptions( + beam_size=cfg.beam, + beam_size_token=cfg.beamsizetoken or len(tgt_dict), + beam_threshold=cfg.beamthreshold, + lm_weight=cfg.lmweight, + word_score=cfg.wordscore, + unk_score=cfg.unkweight, + sil_score=cfg.silweight, + log_add=False, + criterion_type=CriterionType.CTC, + ) + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + [], + self.unitlm, + ) + else: + assert self.unitlm, "Lexicon-free decoding requires unit LM" + + d = {w: [[w]] for w in tgt_dict.symbols} + self.word_dict = create_word_dict(d) + self.lm = KenLM(cfg.lmpath, self.word_dict) + self.decoder_opts = LexiconFreeDecoderOptions( + beam_size=cfg.beam, + beam_size_token=cfg.beamsizetoken or len(tgt_dict), + beam_threshold=cfg.beamthreshold, + lm_weight=cfg.lmweight, + sil_score=cfg.silweight, + log_add=False, + criterion_type=CriterionType.CTC, + ) + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) + + 
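
    # Note: decode() below, like KenLMDecoder.decode() above, hands flashlight a
    # raw pointer into the emissions tensor:
    #     emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
    # The hard-coded 4 is the byte size of a float32 element (get_emissions()
    # casts with .float() and makes the tensor contiguous), and stride(0) is the
    # number of elements per batch item, so the pointer addresses row b of the
    # (B, T, N) emissions.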
def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + B, T, N = emissions.size() + hypos = [] + + def make_hypo(result: DecodeResult) -> Dict[str, Any]: + hypo = { + "tokens": self.get_tokens(result.tokens), + "score": result.score, + } + if self.lexicon: + hypo["words"] = [ + self.idx_to_wrd[x] if self.unitlm else self.word_dict[x] + for x in result.words + if x >= 0 + ] + return hypo + + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] + hypos.append([make_hypo(result) for result in nbest_results]) + self.lm.empty_cache() + + return hypos diff --git a/fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py b/fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a35d95e1464291ed27b7c0b46c4936b5a9184503 --- /dev/null +++ b/fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from typing import List, Dict + +from .base_decoder import BaseDecoder + + +class ViterbiDecoder(BaseDecoder): + def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + def get_pred(e): + score = e.log_softmax(dim=-1).max(dim=-1)[0].sum() + toks = e.argmax(dim=-1).unique_consecutive() + return {"tokens":toks[toks != self.blank], "score":score} + return [[get_pred(x)] for x in emissions] diff --git a/fairseq/examples/speech_recognition/new/infer.py b/fairseq/examples/speech_recognition/new/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5cea4a7c20dbc3eb3787b2f44e3e3817ab8f27 --- /dev/null +++ b/fairseq/examples/speech_recognition/new/infer.py @@ -0,0 +1,502 @@ +#!/usr/bin/env python -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
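
# Note: ViterbiDecoder above is effectively a greedy CTC decoder -- per-frame
# argmax, unique_consecutive() to collapse repeats, blank removal, and a score
# equal to the summed per-frame maximum log-probability. A hedged sketch of the
# same steps on dummy emissions (shapes and the blank index are assumptions):
#
#     import torch
#     emissions = torch.randn(2, 50, 32)        # (B, T, vocab) dummy logits
#     blank = 0                                 # assumed blank index
#     for e in emissions:                       # e has shape (T, vocab)
#         toks = e.argmax(dim=-1).unique_consecutive()
#         toks = toks[toks != blank]
#         score = e.log_softmax(dim=-1).max(dim=-1)[0].sum()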
+ +import ast +import hashlib +import logging +import os +import shutil +import sys +import re +from dataclasses import dataclass, field, is_dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import editdistance +import torch +import torch.distributed as dist +from examples.speech_recognition.new.decoders.decoder_config import ( + DecoderConfig, + FlashlightDecoderConfig, +) +from examples.speech_recognition.new.decoders.decoder import Decoder +from fairseq import checkpoint_utils, distributed_utils, progress_bar, tasks, utils +from fairseq.data.data_utils import post_process +from fairseq.dataclass.configs import ( + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + FairseqDataclass, +) +from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.logging.progress_bar import BaseProgressBar +from fairseq.models.fairseq_model import FairseqModel +from omegaconf import OmegaConf + +import hydra +from hydra.core.config_store import ConfigStore + +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +config_path = Path(__file__).resolve().parent / "conf" + + +@dataclass +class DecodingConfig(DecoderConfig, FlashlightDecoderConfig): + unique_wer_file: bool = field( + default=False, + metadata={"help": "If set, use a unique file for storing WER"}, + ) + results_path: Optional[str] = field( + default=None, + metadata={ + "help": "If set, write hypothesis and reference sentences into this directory" + }, + ) + + +@dataclass +class InferConfig(FairseqDataclass): + task: Any = None + decoding: DecodingConfig = DecodingConfig() + common: CommonConfig = CommonConfig() + common_eval: CommonEvalConfig = CommonEvalConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() + dataset: DatasetConfig = DatasetConfig() + is_ax: bool = field( + default=False, + metadata={ + "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume" + }, + ) + + +def reset_logging(): + root = logging.getLogger() + for handler in root.handlers: + root.removeHandler(handler) + root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + root.addHandler(handler) + + +class InferenceProcessor: + cfg: InferConfig + + def __init__(self, cfg: InferConfig) -> None: + self.cfg = cfg + self.task = tasks.setup_task(cfg.task) + + models, saved_cfg = self.load_model_ensemble() + + ### LOAD ADAPTER #### + ckpt_obj = checkpoint_utils.load_checkpoint_to_cpu(self.cfg.common_eval.path) + if "adapter" in ckpt_obj: + target_lang = self.cfg.dataset.gen_subset.split(":")[0] + assert target_lang in ckpt_obj["adapter"] + + logger.info(f">>> LOADING ADAPTER: {target_lang}") + ft_obj = ckpt_obj["adapter"][target_lang] + ft_model = ft_obj["model"] + cdevice = models[0].w2v_encoder.proj.weight.device + cdtype = models[0].w2v_encoder.proj.weight.dtype + ft_proj_out, ft_proj_in = ft_model["w2v_encoder.proj.weight"].shape + ft_proj = torch.nn.Linear(ft_proj_in, ft_proj_out, bias=True) + ft_proj.to(device=cdevice, dtype=cdtype) + models[0].w2v_encoder.proj = ft_proj + with torch.no_grad(): + for kk, vv in models[0].named_parameters(): + if kk in ft_model: + vv.copy_(ft_model[kk]) + 
self.task.load_state_dict(ft_obj["task_state"]) + # overwrite gen_subset with master config + self.cfg.dataset.gen_subset = re.sub('^[\w-]+:', saved_cfg['task']['multi_corpus_keys']+":", self.cfg.dataset.gen_subset) + self.models = models + self.saved_cfg = saved_cfg + self.tgt_dict = self.task.target_dictionary + + self.task.load_dataset( + self.cfg.dataset.gen_subset, + task_cfg=saved_cfg.task, + ) + self.generator = Decoder(cfg.decoding, self.tgt_dict) + self.gen_timer = StopwatchMeter() + self.wps_meter = TimeMeter() + self.num_sentences = 0 + self.total_errors = 0 + self.total_length = 0 + + self.hypo_words_file = None + self.hypo_units_file = None + self.ref_words_file = None + self.ref_units_file = None + self.score_file = None + + self.progress_bar = self.build_progress_bar() + + def __enter__(self) -> "InferenceProcessor": + if self.cfg.decoding.results_path is not None: + self.hypo_words_file = self.get_res_file("hypo.word") + self.hypo_units_file = self.get_res_file("hypo.units") + self.ref_words_file = self.get_res_file("ref.word") + self.ref_units_file = self.get_res_file("ref.units") + self.score_file = self.get_res_file("asr_score") + return self + + def __exit__(self, *exc) -> bool: + if self.cfg.decoding.results_path is not None: + self.hypo_words_file.close() + self.hypo_units_file.close() + self.ref_words_file.close() + self.ref_units_file.close() + self.score_file.close() + return False + + def __iter__(self) -> Any: + for sample in self.progress_bar: + if not self.cfg.common.cpu: + sample = utils.move_to_cuda(sample) + + # Happens on the last batch. + if "net_input" not in sample: + continue + yield sample + + def log(self, *args, **kwargs): + self.progress_bar.log(*args, **kwargs) + + def print(self, *args, **kwargs): + self.progress_bar.print(*args, **kwargs) + + def get_res_file(self, fname: str) -> None: + fname = os.path.join(self.cfg.decoding.results_path, fname) + if self.data_parallel_world_size > 1: + fname = f"{fname}.{self.data_parallel_rank}" + return open(fname, "w", buffering=1) + + def merge_shards(self) -> None: + """Merges all shard files into shard 0, then removes shard suffix.""" + + shard_id = self.data_parallel_rank + num_shards = self.data_parallel_world_size + + if self.data_parallel_world_size > 1: + + def merge_shards_with_root(fname: str) -> None: + fname = os.path.join(self.cfg.decoding.results_path, fname) + logger.info("Merging %s on shard %d", fname, shard_id) + base_fpath = Path(f"{fname}.0") + with open(base_fpath, "a") as out_file: + for s in range(1, num_shards): + shard_fpath = Path(f"{fname}.{s}") + with open(shard_fpath, "r") as in_file: + for line in in_file: + out_file.write(line) + shard_fpath.unlink() + shutil.move(f"{fname}.0", fname) + + dist.barrier() # ensure all shards finished writing + if shard_id == (0 % num_shards): + merge_shards_with_root("hypo.word") + if shard_id == (1 % num_shards): + merge_shards_with_root("hypo.units") + if shard_id == (2 % num_shards): + merge_shards_with_root("ref.word") + if shard_id == (3 % num_shards): + merge_shards_with_root("ref.units") + dist.barrier() + + def optimize_model(self, model: FairseqModel) -> None: + model.make_generation_fast_() + if self.cfg.common.fp16: + model.half() + if not self.cfg.common.cpu: + model.cuda() + + def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]: + arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides) + models, saved_cfg = checkpoint_utils.load_model_ensemble( + utils.split_paths(self.cfg.common_eval.path, 
separator="\\"), + arg_overrides=arg_overrides, + task=self.task, + suffix=self.cfg.checkpoint.checkpoint_suffix, + strict=(self.cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=self.cfg.checkpoint.checkpoint_shard_count, + ) + for model in models: + self.optimize_model(model) + return models, saved_cfg + + def get_dataset_itr(self, disable_iterator_cache: bool = False) -> None: + return self.task.get_batch_iterator( + dataset=self.task.dataset(self.cfg.dataset.gen_subset), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=(sys.maxsize, sys.maxsize), + ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, + num_shards=self.data_parallel_world_size, + shard_id=self.data_parallel_rank, + num_workers=self.cfg.dataset.num_workers, + data_buffer_size=self.cfg.dataset.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + ).next_epoch_itr(shuffle=False) + + def build_progress_bar( + self, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + default_log_format: str = "tqdm", + ) -> BaseProgressBar: + return progress_bar.progress_bar( + iterator=self.get_dataset_itr(), + log_format=self.cfg.common.log_format, + log_interval=self.cfg.common.log_interval, + epoch=epoch, + prefix=prefix, + tensorboard_logdir=self.cfg.common.tensorboard_logdir, + default_log_format=default_log_format, + ) + + @property + def data_parallel_world_size(self): + if self.cfg.distributed_training.distributed_world_size == 1: + return 1 + return distributed_utils.get_data_parallel_world_size() + + @property + def data_parallel_rank(self): + if self.cfg.distributed_training.distributed_world_size == 1: + return 0 + return distributed_utils.get_data_parallel_rank() + + def process_sentence( + self, + sample: Dict[str, Any], + hypo: Dict[str, Any], + sid: int, + batch_id: int, + ) -> Tuple[int, int]: + speaker = None # Speaker can't be parsed from dataset. + if "target_label" in sample: + toks = sample["target_label"] + else: + toks = sample["target"] + toks = toks[batch_id, :] + + # Processes hypothesis. + hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu()) + if "words" in hypo: + hyp_words = " ".join(hypo["words"]) + else: + hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process) + + # Processes target. 
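        # Note: strip_pad() drops padding before the ids are converted to a unit
        # string. With the default `post_process: letter` from conf/infer.yaml,
        # post_process() is expected to join the letter tokens and turn the "|"
        # word-boundary symbol into spaces (e.g. "H E L L O |" -> "HELLO"), so
        # the WER computed below operates on whitespace-separated words.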
+ target_tokens = utils.strip_pad(toks, self.tgt_dict.pad()) + tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu()) + tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process) + + if self.cfg.decoding.results_path is not None: + print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file) + print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file) + print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file) + print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file) + print(f"{hypo['score'].item()} ({speaker}-{sid})", file=self.score_file) + + if not self.cfg.common_eval.quiet: + logger.info(f"HYPO: {hyp_words}") + logger.info(f"REF: {tgt_words}") + logger.info("---------------------") + + hyp_words, tgt_words = hyp_words.split(), tgt_words.split() + + return editdistance.eval(hyp_words, tgt_words), len(tgt_words) + + def process_sample(self, sample: Dict[str, Any]) -> None: + self.gen_timer.start() + hypos = self.task.inference_step( + generator=self.generator, + models=self.models, + sample=sample, + ) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + self.gen_timer.stop(num_generated_tokens) + self.wps_meter.update(num_generated_tokens) + + for batch_id, sample_id in enumerate(sample["id"].tolist()): + errs, length = self.process_sentence( + sample=sample, + sid=sample_id, + batch_id=batch_id, + hypo=hypos[batch_id][0], + ) + self.total_errors += errs + self.total_length += length + + self.log({"wps": round(self.wps_meter.avg)}) + if "nsentences" in sample: + self.num_sentences += sample["nsentences"] + else: + self.num_sentences += sample["id"].numel() + + def log_generation_time(self) -> None: + logger.info( + "Processed %d sentences (%d tokens) in %.1fs %.2f " + "sentences per second, %.2f tokens per second)", + self.num_sentences, + self.gen_timer.n, + self.gen_timer.sum, + self.num_sentences / (self.gen_timer.sum + 1e-6), + 1.0 / (self.gen_timer.avg + 1e-6), + ) + + +def parse_wer(wer_file: Path) -> float: + with open(wer_file, "r") as f: + return float(f.readline().strip().split(" ")[1]) + + +def get_wer_file(cfg: InferConfig) -> Path: + """Hashes the decoding parameters to a unique file ID.""" + base_path = "wer" + if cfg.decoding.results_path is not None: + base_path = os.path.join(cfg.decoding.results_path, base_path) + + if cfg.decoding.unique_wer_file: + yaml_str = OmegaConf.to_yaml(cfg.decoding) + fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16) + return Path(f"{base_path}.{fid % 1000000}") + else: + return Path(base_path) + + +def main(cfg: InferConfig) -> float: + """Entry point for main processing logic. + + Args: + cfg: The inferance configuration to use. + wer: Optional shared memory pointer for returning the WER. If not None, + the final WER value will be written here instead of being returned. + + Returns: + The final WER if `wer` is None, otherwise None. + """ + + yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg) + + # Validates the provided configuration. 
+ if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 4000000 + if not cfg.common.cpu and not torch.cuda.is_available(): + raise ValueError("CUDA not found; set `cpu=True` to run without CUDA") + + logger.info(cfg.common_eval.path) + + with InferenceProcessor(cfg) as processor: + for sample in processor: + processor.process_sample(sample) + + processor.log_generation_time() + + if cfg.decoding.results_path is not None: + processor.merge_shards() + + errs_t, leng_t = processor.total_errors, processor.total_length + + if cfg.common.cpu: + logger.warning("Merging WER requires CUDA.") + elif processor.data_parallel_world_size > 1: + stats = torch.LongTensor([errs_t, leng_t]).cuda() + dist.all_reduce(stats, op=dist.ReduceOp.SUM) + errs_t, leng_t = stats[0].item(), stats[1].item() + + wer = errs_t * 100.0 / leng_t + + if distributed_utils.is_master(cfg.distributed_training): + with open(wer_file, "w") as f: + f.write( + ( + f"WER: {wer}\n" + f"err / num_ref_words = {errs_t} / {leng_t}\n\n" + f"{yaml_str}" + ) + ) + + return wer + + +@hydra.main(config_path=config_path, config_name="infer") +def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]: + container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + cfg = OmegaConf.create(container) + OmegaConf.set_struct(cfg, True) + + if cfg.common.reset_logging: + reset_logging() + + utils.import_user_module(cfg.common) + + # logger.info("Config:\n%s", OmegaConf.to_yaml(cfg)) + wer = float("inf") + + try: + if cfg.common.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, main) + else: + distributed_utils.call_main(cfg, main) + + wer = parse_wer(get_wer_file(cfg)) + except BaseException as e: # pylint: disable=broad-except + if not cfg.common.suppress_crashes: + raise + else: + logger.error("Crashed! %s", str(e)) + + logger.info("Word error rate: %.4f", wer) + if cfg.is_ax: + return wer, None + + return wer + + +def cli_main() -> None: + try: + from hydra._internal.utils import ( + get_args, + ) # pylint: disable=import-outside-toplevel + + cfg_name = get_args().config_name or "infer" + except ImportError: + logger.warning("Failed to get config name from hydra args") + cfg_name = "infer" + + cs = ConfigStore.instance() + cs.store(name=cfg_name, node=InferConfig) + + for k in InferConfig.__dataclass_fields__: + if is_dataclass(InferConfig.__dataclass_fields__[k].type): + v = InferConfig.__dataclass_fields__[k].default + cs.store(name=k, node=v) + + hydra_main() # pylint: disable=no-value-for-parameter + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/speech_recognition/tasks/__init__.py b/fairseq/examples/speech_recognition/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac3b8dc69639c92cc129294356e9012745e3fb2 --- /dev/null +++ b/fairseq/examples/speech_recognition/tasks/__init__.py @@ -0,0 +1,8 @@ +import importlib +import os + + +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + task_name = file[: file.find(".py")] + importlib.import_module("examples.speech_recognition.tasks." 
+ task_name) diff --git a/fairseq/examples/speech_recognition/tasks/speech_recognition.py b/fairseq/examples/speech_recognition/tasks/speech_recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f011d55ff4fdfeb4c04ca790c314d685708c3a --- /dev/null +++ b/fairseq/examples/speech_recognition/tasks/speech_recognition.py @@ -0,0 +1,157 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import re +import sys + +import torch +from examples.speech_recognition.data import AsrDataset +from examples.speech_recognition.data.replabels import replabel_symbol +from fairseq.data import Dictionary +from fairseq.tasks import LegacyFairseqTask, register_task + + +def get_asr_dataset_from_json(data_json_path, tgt_dict): + """ + Parse data json and create dataset. + See scripts/asr_prep_json.py which pack json from raw files + + Json example: + { + "utts": { + "4771-29403-0025": { + "input": { + "length_ms": 170, + "path": "/tmp/file1.flac" + }, + "output": { + "text": "HELLO \n", + "token": "HE LLO", + "tokenid": "4815, 861" + } + }, + "1564-142299-0096": { + ... + } + } + """ + if not os.path.isfile(data_json_path): + raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) + with open(data_json_path, "rb") as f: + data_samples = json.load(f)["utts"] + assert len(data_samples) != 0 + sorted_samples = sorted( + data_samples.items(), + key=lambda sample: int(sample[1]["input"]["length_ms"]), + reverse=True, + ) + aud_paths = [s[1]["input"]["path"] for s in sorted_samples] + ids = [s[0] for s in sorted_samples] + speakers = [] + for s in sorted_samples: + m = re.search("(.+?)-(.+?)-(.+?)", s[0]) + speakers.append(m.group(1) + "_" + m.group(2)) + frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples] + tgt = [ + [int(i) for i in s[1]["output"]["tokenid"].split(", ")] + for s in sorted_samples + ] + # append eos + tgt = [[*t, tgt_dict.eos()] for t in tgt] + return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers) + + +@register_task("speech_recognition") +class SpeechRecognitionTask(LegacyFairseqTask): + """ + Task for training speech recognition model. 
+ """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", help="path to data directory") + parser.add_argument( + "--silence-token", default="\u2581", help="token for silence (used by w2l)" + ) + parser.add_argument( + "--max-source-positions", + default=sys.maxsize, + type=int, + metavar="N", + help="max number of frames in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + + def __init__(self, args, tgt_dict): + super().__init__(args) + self.tgt_dict = tgt_dict + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries).""" + dict_path = os.path.join(args.data, "dict.txt") + if not os.path.isfile(dict_path): + raise FileNotFoundError("Dict not found: {}".format(dict_path)) + tgt_dict = Dictionary.load(dict_path) + + if args.criterion == "ctc_loss": + tgt_dict.add_symbol("") + elif args.criterion == "asg_loss": + for i in range(1, args.max_replabel + 1): + tgt_dict.add_symbol(replabel_symbol(i)) + + print("| dictionary: {} types".format(len(tgt_dict))) + return cls(args, tgt_dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + data_json_path = os.path.join(self.args.data, "{}.json".format(split)) + self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict) + + def build_generator(self, models, args, **unused): + w2l_decoder = getattr(args, "w2l_decoder", None) + if w2l_decoder == "viterbi": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, self.target_dictionary) + elif w2l_decoder == "kenlm": + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + return W2lKenLMDecoder(args, self.target_dictionary) + elif w2l_decoder == "fairseqlm": + from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder + + return W2lFairseqLMDecoder(args, self.target_dictionary) + else: + return super().build_generator(models, args) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.tgt_dict + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None + + def max_positions(self): + """Return the max speech and sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) diff --git a/fairseq/examples/speech_recognition/utils/wer_utils.py b/fairseq/examples/speech_recognition/utils/wer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6f3d09ba41a46ad4d7968fb3c286dd53d15c38 --- /dev/null +++ b/fairseq/examples/speech_recognition/utils/wer_utils.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import re +from collections import deque +from enum import Enum + +import numpy as np + + +""" + Utility modules for computation of Word Error Rate, + Alignments, as well as more granular metrics like + deletion, insersion and substitutions. 
+""" + + +class Code(Enum): + match = 1 + substitution = 2 + insertion = 3 + deletion = 4 + + +class Token(object): + def __init__(self, lbl="", st=np.nan, en=np.nan): + if np.isnan(st): + self.label, self.start, self.end = "", 0.0, 0.0 + else: + self.label, self.start, self.end = lbl, st, en + + +class AlignmentResult(object): + def __init__(self, refs, hyps, codes, score): + self.refs = refs # std::deque + self.hyps = hyps # std::deque + self.codes = codes # std::deque + self.score = score # float + + +def coordinate_to_offset(row, col, ncols): + return int(row * ncols + col) + + +def offset_to_row(offset, ncols): + return int(offset / ncols) + + +def offset_to_col(offset, ncols): + return int(offset % ncols) + + +def trimWhitespace(str): + return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str))) + + +def str2toks(str): + pieces = trimWhitespace(str).split(" ") + toks = [] + for p in pieces: + toks.append(Token(p, 0.0, 0.0)) + return toks + + +class EditDistance(object): + def __init__(self, time_mediated): + self.time_mediated_ = time_mediated + self.scores_ = np.nan # Eigen::Matrix + self.backtraces_ = ( + np.nan + ) # Eigen::Matrix backtraces_; + self.confusion_pairs_ = {} + + def cost(self, ref, hyp, code): + if self.time_mediated_: + if code == Code.match: + return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + elif code == Code.insertion: + return hyp.end - hyp.start + elif code == Code.deletion: + return ref.end - ref.start + else: # substitution + return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1 + else: + if code == Code.match: + return 0 + elif code == Code.insertion or code == Code.deletion: + return 3 + else: # substitution + return 4 + + def get_result(self, refs, hyps): + res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan) + + num_rows, num_cols = self.scores_.shape + res.score = self.scores_[num_rows - 1, num_cols - 1] + + curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols) + + while curr_offset != 0: + curr_row = offset_to_row(curr_offset, num_cols) + curr_col = offset_to_col(curr_offset, num_cols) + + prev_offset = self.backtraces_[curr_row, curr_col] + + prev_row = offset_to_row(prev_offset, num_cols) + prev_col = offset_to_col(prev_offset, num_cols) + + res.refs.appendleft(curr_row - 1) # Note: this was .push_front() in C++ + res.hyps.appendleft(curr_col - 1) + if curr_row - 1 == prev_row and curr_col == prev_col: + res.codes.appendleft(Code.deletion) + elif curr_row == prev_row and curr_col - 1 == prev_col: + res.codes.appendleft(Code.insertion) + else: + # assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col) + ref_str = refs[res.refs[0]].label + hyp_str = hyps[res.hyps[0]].label + + if ref_str == hyp_str: + res.codes.appendleft(Code.match) + else: + res.codes.appendleft(Code.substitution) + + confusion_pair = "%s -> %s" % (ref_str, hyp_str) + if confusion_pair not in self.confusion_pairs_: + self.confusion_pairs_[confusion_pair] = 1 + else: + self.confusion_pairs_[confusion_pair] += 1 + + curr_offset = prev_offset + + return res + + def align(self, refs, hyps): + if len(refs) == 0 and len(hyps) == 0: + return np.nan + + # NOTE: we're not resetting the values in these matrices because every value + # will be overridden in the loop below. If this assumption doesn't hold, + # be sure to set all entries in self.scores_ and self.backtraces_ to 0. 
+ self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1)) + self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1)) + + num_rows, num_cols = self.scores_.shape + + for i in range(num_rows): + for j in range(num_cols): + if i == 0 and j == 0: + self.scores_[i, j] = 0.0 + self.backtraces_[i, j] = 0 + continue + + if i == 0: + self.scores_[i, j] = self.scores_[i, j - 1] + self.cost( + None, hyps[j - 1], Code.insertion + ) + self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols) + continue + + if j == 0: + self.scores_[i, j] = self.scores_[i - 1, j] + self.cost( + refs[i - 1], None, Code.deletion + ) + self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols) + continue + + # Below here both i and j are greater than 0 + ref = refs[i - 1] + hyp = hyps[j - 1] + best_score = self.scores_[i - 1, j - 1] + ( + self.cost(ref, hyp, Code.match) + if (ref.label == hyp.label) + else self.cost(ref, hyp, Code.substitution) + ) + + prev_row = i - 1 + prev_col = j - 1 + ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion) + if ins < best_score: + best_score = ins + prev_row = i + prev_col = j - 1 + + delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion) + if delt < best_score: + best_score = delt + prev_row = i - 1 + prev_col = j + + self.scores_[i, j] = best_score + self.backtraces_[i, j] = coordinate_to_offset( + prev_row, prev_col, num_cols + ) + + return self.get_result(refs, hyps) + + +class WERTransformer(object): + def __init__(self, hyp_str, ref_str, verbose=True): + self.ed_ = EditDistance(False) + self.id2oracle_errs_ = {} + self.utts_ = 0 + self.words_ = 0 + self.insertions_ = 0 + self.deletions_ = 0 + self.substitutions_ = 0 + + self.process(["dummy_str", hyp_str, ref_str]) + + if verbose: + print("'%s' vs '%s'" % (hyp_str, ref_str)) + self.report_result() + + def process(self, input): # std::vector&& input + if len(input) < 3: + print( + "Input must be of the form ... 
, got ", + len(input), + " inputs:", + ) + return None + + # Align + # std::vector hyps; + # std::vector refs; + + hyps = str2toks(input[-2]) + refs = str2toks(input[-1]) + + alignment = self.ed_.align(refs, hyps) + if alignment is None: + print("Alignment is null") + return np.nan + + # Tally errors + ins = 0 + dels = 0 + subs = 0 + for code in alignment.codes: + if code == Code.substitution: + subs += 1 + elif code == Code.insertion: + ins += 1 + elif code == Code.deletion: + dels += 1 + + # Output + row = input + row.append(str(len(refs))) + row.append(str(ins)) + row.append(str(dels)) + row.append(str(subs)) + # print(row) + + # Accumulate + kIdIndex = 0 + kNBestSep = "/" + + pieces = input[kIdIndex].split(kNBestSep) + + if len(pieces) == 0: + print( + "Error splitting ", + input[kIdIndex], + " on '", + kNBestSep, + "', got empty list", + ) + return np.nan + + id = pieces[0] + if id not in self.id2oracle_errs_: + self.utts_ += 1 + self.words_ += len(refs) + self.insertions_ += ins + self.deletions_ += dels + self.substitutions_ += subs + self.id2oracle_errs_[id] = [ins, dels, subs] + else: + curr_err = ins + dels + subs + prev_err = np.sum(self.id2oracle_errs_[id]) + if curr_err < prev_err: + self.id2oracle_errs_[id] = [ins, dels, subs] + + return 0 + + def report_result(self): + # print("---------- Summary ---------------") + if self.words_ == 0: + print("No words counted") + return + + # 1-best + best_wer = ( + 100.0 + * (self.insertions_ + self.deletions_ + self.substitutions_) + / self.words_ + ) + + print( + "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, " + "%0.2f%% dels, %0.2f%% subs)" + % ( + best_wer, + self.utts_, + self.words_, + 100.0 * self.insertions_ / self.words_, + 100.0 * self.deletions_ / self.words_, + 100.0 * self.substitutions_ / self.words_, + ) + ) + + def wer(self): + if self.words_ == 0: + wer = np.nan + else: + wer = ( + 100.0 + * (self.insertions_ + self.deletions_ + self.substitutions_) + / self.words_ + ) + return wer + + def stats(self): + if self.words_ == 0: + stats = {} + else: + wer = ( + 100.0 + * (self.insertions_ + self.deletions_ + self.substitutions_) + / self.words_ + ) + stats = dict( + { + "wer": wer, + "utts": self.utts_, + "numwords": self.words_, + "ins": self.insertions_, + "dels": self.deletions_, + "subs": self.substitutions_, + "confusion_pairs": self.ed_.confusion_pairs_, + } + ) + return stats + + +def calc_wer(hyp_str, ref_str): + t = WERTransformer(hyp_str, ref_str, verbose=0) + return t.wer() + + +def calc_wer_stats(hyp_str, ref_str): + t = WERTransformer(hyp_str, ref_str, verbose=0) + return t.stats() + + +def get_wer_alignment_codes(hyp_str, ref_str): + """ + INPUT: hypothesis string, reference string + OUTPUT: List of alignment codes (intermediate results from WER computation) + """ + t = WERTransformer(hyp_str, ref_str, verbose=0) + return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes + + +def merge_counts(x, y): + # Merge two hashes which have 'counts' as their values + # This can be used for example to merge confusion pair counts + # conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs']) + for k, v in y.items(): + if k not in x: + x[k] = 0 + x[k] += v + return x diff --git a/fairseq/examples/speech_recognition/w2l_decoder.py b/fairseq/examples/speech_recognition/w2l_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf2d3524ee40bd0d08b6a9560047d96e49b6045 --- /dev/null +++ b/fairseq/examples/speech_recognition/w2l_decoder.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python3 + +# 
Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Flashlight decoders. +""" + +import gc +import itertools as it +import os.path as osp +from typing import List +import warnings +from collections import deque, namedtuple + +import numpy as np +import torch +from examples.speech_recognition.data.replabels import unpack_replabels +from fairseq import tasks +from fairseq.utils import apply_to_sample +from omegaconf import open_dict +from fairseq.dataclass.utils import convert_namespace_to_omegaconf + + +try: + from flashlight.lib.text.dictionary import create_word_dict, load_words + from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes + from flashlight.lib.text.decoder import ( + CriterionType, + LexiconDecoderOptions, + KenLM, + LM, + LMState, + SmearingMode, + Trie, + LexiconDecoder, + ) +except: + warnings.warn( + "flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python" + ) + LM = object + LMState = object + + +class W2lDecoder(object): + def __init__(self, args, tgt_dict): + self.tgt_dict = tgt_dict + self.vocab_size = len(tgt_dict) + self.nbest = args.nbest + + # criterion-specific init + self.criterion_type = CriterionType.CTC + self.blank = ( + tgt_dict.index("") + if "" in tgt_dict.indices + else tgt_dict.bos() + ) + if "" in tgt_dict.indices: + self.silence = tgt_dict.index("") + elif "|" in tgt_dict.indices: + self.silence = tgt_dict.index("|") + else: + self.silence = tgt_dict.eos() + self.asg_transitions = None + + def generate(self, models, sample, **unused): + """Generate a batch of inferences.""" + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" + } + emissions = self.get_emissions(models, encoder_input) + return self.decode(emissions) + + def get_emissions(self, models, encoder_input): + """Run encoder and normalize emissions""" + model = models[0] + encoder_out = model(**encoder_input) + if hasattr(model, "get_logits"): + emissions = model.get_logits(encoder_out) # no need to normalize emissions + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) + return emissions.transpose(0, 1).float().cpu().contiguous() + + def get_tokens(self, idxs): + """Normalize tokens by handling CTC blank, ASG replabels, etc.""" + idxs = (g[0] for g in it.groupby(idxs)) + idxs = filter(lambda x: x != self.blank, idxs) + return torch.LongTensor(list(idxs)) + + +class W2lViterbiDecoder(W2lDecoder): + def __init__(self, args, tgt_dict): + super().__init__(args, tgt_dict) + + def decode(self, emissions): + B, T, N = emissions.size() + hypos = [] + if self.asg_transitions is None: + transitions = torch.FloatTensor(N, N).zero_() + else: + transitions = torch.FloatTensor(self.asg_transitions).view(N, N) + viterbi_path = torch.IntTensor(B, T) + workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N)) + CpuViterbiPath.compute( + B, + T, + N, + get_data_ptr_as_bytes(emissions), + get_data_ptr_as_bytes(transitions), + get_data_ptr_as_bytes(viterbi_path), + get_data_ptr_as_bytes(workspace), + ) + return [ + [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] + for b in range(B) + ] + + +class 
W2lKenLMDecoder(W2lDecoder): + def __init__(self, args, tgt_dict): + super().__init__(args, tgt_dict) + + self.unit_lm = getattr(args, "unit_lm", False) + + if args.lexicon: + self.lexicon = load_words(args.lexicon) + self.word_dict = create_word_dict(self.lexicon) + self.unk_word = self.word_dict.get_index("") + + self.lm = KenLM(args.kenlm_model, self.word_dict) + self.trie = Trie(self.vocab_size, self.silence) + + start_state = self.lm.start(False) + for i, (word, spellings) in enumerate(self.lexicon.items()): + word_idx = self.word_dict.get_index(word) + _, score = self.lm.score(start_state, word_idx) + for spelling in spellings: + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder_opts = LexiconDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + word_score=args.word_score, + unk_score=args.unk_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, + ) + + if self.asg_transitions is None: + N = 768 + # self.asg_transitions = torch.FloatTensor(N, N).zero_() + self.asg_transitions = [] + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + self.asg_transitions, + self.unit_lm, + ) + else: + assert args.unit_lm, "lexicon free decoding can only be done with a unit language model" + from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions + + d = {w: [[w]] for w in tgt_dict.symbols} + self.word_dict = create_word_dict(d) + self.lm = KenLM(args.kenlm_model, self.word_dict) + self.decoder_opts = LexiconFreeDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, + ) + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) + + def get_timesteps(self, token_idxs: List[int]) -> List[int]: + """Returns frame numbers corresponding to every non-blank token. + + Parameters + ---------- + token_idxs : List[int] + IDs of decoded tokens. + + Returns + ------- + List[int] + Frame numbers corresponding to every non-blank token. 
+ """ + timesteps = [] + for i, token_idx in enumerate(token_idxs): + if token_idx == self.blank: + continue + if i == 0 or token_idx != token_idxs[i-1]: + timesteps.append(i) + return timesteps + + def decode(self, emissions): + B, T, N = emissions.size() + hypos = [] + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] + hypos.append( + [ + { + "tokens": self.get_tokens(result.tokens), + "score": result.score, + "timesteps": self.get_timesteps(result.tokens), + "words": [ + self.word_dict.get_entry(x) for x in result.words if x >= 0 + ], + } + for result in nbest_results + ] + ) + return hypos + + +FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"]) + + +class FairseqLM(LM): + def __init__(self, dictionary, model): + LM.__init__(self) + self.dictionary = dictionary + self.model = model + self.unk = self.dictionary.unk() + + self.save_incremental = False # this currently does not work properly + self.max_cache = 20_000 + + model.cuda() + model.eval() + model.make_generation_fast_() + + self.states = {} + self.stateq = deque() + + def start(self, start_with_nothing): + state = LMState() + prefix = torch.LongTensor([[self.dictionary.eos()]]) + incremental_state = {} if self.save_incremental else None + with torch.no_grad(): + res = self.model(prefix.cuda(), incremental_state=incremental_state) + probs = self.model.get_normalized_probs(res, log_probs=True, sample=None) + + if incremental_state is not None: + incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state) + self.states[state] = FairseqLMState( + prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy() + ) + self.stateq.append(state) + + return state + + def score(self, state: LMState, token_index: int, no_cache: bool = False): + """ + Evaluate language model based on the current lm state and new word + Parameters: + ----------- + state: current lm state + token_index: index of the word + (can be lexicon index then you should store inside LM the + mapping between indices of lexicon and lm, or lm index of a word) + + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + curr_state = self.states[state] + + def trim_cache(targ_size): + while len(self.stateq) > targ_size: + rem_k = self.stateq.popleft() + rem_st = self.states[rem_k] + rem_st = FairseqLMState(rem_st.prefix, None, None) + self.states[rem_k] = rem_st + + if curr_state.probs is None: + new_incremental_state = ( + curr_state.incremental_state.copy() + if curr_state.incremental_state is not None + else None + ) + with torch.no_grad(): + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cuda(), new_incremental_state + ) + elif self.save_incremental: + new_incremental_state = {} + + res = self.model( + torch.from_numpy(curr_state.prefix).cuda(), + incremental_state=new_incremental_state, + ) + probs = self.model.get_normalized_probs( + res, log_probs=True, sample=None + ) + + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cpu(), new_incremental_state + ) + + curr_state = FairseqLMState( + curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy() + ) + + if not no_cache: + self.states[state] = curr_state + self.stateq.append(state) + + score = curr_state.probs[token_index].item() + + trim_cache(self.max_cache) + + outstate = state.child(token_index) 
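+        # When caching is enabled, the child LMState is stored below with the
+        # extended token prefix, the parent's incremental_state, and probs=None,
+        # so its probabilities are only computed lazily on a later score() call.
+        # Unknown tokens receive a score of -inf.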
+ if outstate not in self.states and not no_cache: + prefix = np.concatenate( + [curr_state.prefix, torch.LongTensor([[token_index]])], -1 + ) + incr_state = curr_state.incremental_state + + self.states[outstate] = FairseqLMState(prefix, incr_state, None) + + if token_index == self.unk: + score = float("-inf") + + return outstate, score + + def finish(self, state: LMState): + """ + Evaluate eos for language model based on the current lm state + + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + return self.score(state, self.dictionary.eos()) + + def empty_cache(self): + self.states = {} + self.stateq = deque() + gc.collect() + + +class W2lFairseqLMDecoder(W2lDecoder): + def __init__(self, args, tgt_dict): + super().__init__(args, tgt_dict) + + self.unit_lm = getattr(args, "unit_lm", False) + + self.lexicon = load_words(args.lexicon) if args.lexicon else None + self.idx_to_wrd = {} + + checkpoint = torch.load(args.kenlm_model, map_location="cpu") + + if "cfg" in checkpoint and checkpoint["cfg"] is not None: + lm_args = checkpoint["cfg"] + else: + lm_args = convert_namespace_to_omegaconf(checkpoint["args"]) + + with open_dict(lm_args.task): + lm_args.task.data = osp.dirname(args.kenlm_model) + + task = tasks.setup_task(lm_args.task) + model = task.build_model(lm_args.model) + model.load_state_dict(checkpoint["model"], strict=False) + + self.trie = Trie(self.vocab_size, self.silence) + + self.word_dict = task.dictionary + self.unk_word = self.word_dict.unk() + self.lm = FairseqLM(self.word_dict, model) + + if self.lexicon: + start_state = self.lm.start(False) + for i, (word, spellings) in enumerate(self.lexicon.items()): + if self.unit_lm: + word_idx = i + self.idx_to_wrd[i] = word + score = 0 + else: + word_idx = self.word_dict.index(word) + _, score = self.lm.score(start_state, word_idx, no_cache=True) + + for spelling in spellings: + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder_opts = LexiconDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + word_score=args.word_score, + unk_score=args.unk_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, + ) + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + [], + self.unit_lm, + ) + else: + assert args.unit_lm, "lexicon free decoding can only be done with a unit language model" + from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions + + d = {w: [[w]] for w in tgt_dict.symbols} + self.word_dict = create_word_dict(d) + self.lm = KenLM(args.kenlm_model, self.word_dict) + self.decoder_opts = LexiconFreeDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, + ) + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) + + def decode(self, emissions): + B, T, N = emissions.size() + hypos = [] + + def idx_to_word(idx): + if self.unit_lm: + return self.idx_to_wrd[idx] + else: + return 
self.word_dict[idx] + + def make_hypo(result): + hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score} + if self.lexicon: + hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0] + return hypo + + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] + hypos.append([make_hypo(result) for result in nbest_results]) + self.lm.empty_cache() + + return hypos diff --git a/fairseq/examples/speech_synthesis/README.md b/fairseq/examples/speech_synthesis/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a31e7f68bd3128494d70c0a4dfa75b9d54aa4288 --- /dev/null +++ b/fairseq/examples/speech_synthesis/README.md @@ -0,0 +1,38 @@ +Speech Synthesis (S^2) +=== +[https://arxiv.org/abs/2109.06912](https://arxiv.org/abs/2109.06912) + +Speech synthesis with fairseq. + +## Features + +- Autoregressive and non-autoregressive models +- Multi-speaker synthesis +- Audio preprocessing (denoising, VAD, etc.) for less curated data +- Automatic metrics for model development +- Similar data configuration as [S2T](../speech_to_text/README.md) + + +## Examples +- [Single-speaker synthesis on LJSpeech](docs/ljspeech_example.md) +- [Multi-speaker synthesis on VCTK](docs/vctk_example.md) +- [Multi-speaker synthesis on Common Voice](docs/common_voice_example.md) + + +## Citation +Please cite as: +``` +@article{wang2021fairseqs2, + title={fairseq S\^{} 2: A Scalable and Integrable Speech Synthesis Toolkit}, + author={Wang, Changhan and Hsu, Wei-Ning and Adi, Yossi and Polyak, Adam and Lee, Ann and Chen, Peng-Jen and Gu, Jiatao and Pino, Juan}, + journal={arXiv preprint arXiv:2109.06912}, + year={2021} +} + +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` diff --git a/fairseq/examples/speech_synthesis/__init__.py b/fairseq/examples/speech_synthesis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a --- /dev/null +++ b/fairseq/examples/speech_synthesis/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/examples/speech_synthesis/data_utils.py b/fairseq/examples/speech_synthesis/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2d079a9a8273a7331f2a07c46eaefe568947cf --- /dev/null +++ b/fairseq/examples/speech_synthesis/data_utils.py @@ -0,0 +1,344 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
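+
+# Helpers for S^2 data preprocessing: log-Mel spectrogram extraction, pitch
+# (PyWORLD dio + stonemask) and energy targets, global CMVN statistics,
+# IPA / g2p phonemization, and loading of MFA TextGrid or pseudo-unit alignments.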
+ +import io +import os +from pathlib import Path +from typing import Optional, List, Dict +import zipfile +import tempfile +from dataclasses import dataclass +from itertools import groupby + +import torch +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm + +from examples.speech_to_text.data_utils import load_tsv_to_dicts +from fairseq.data.audio.audio_utils import ( + TTSSpectrogram, TTSMelScale, parse_path, read_from_stored_zip, is_npy_data +) + + +def trim_or_pad_to_target_length( + data_1d_or_2d: np.ndarray, target_length: int +) -> np.ndarray: + assert len(data_1d_or_2d.shape) in {1, 2} + delta = data_1d_or_2d.shape[0] - target_length + if delta >= 0: # trim if being longer + data_1d_or_2d = data_1d_or_2d[: target_length] + else: # pad if being shorter + if len(data_1d_or_2d.shape) == 1: + data_1d_or_2d = np.concatenate( + [data_1d_or_2d, np.zeros(-delta)], axis=0 + ) + else: + data_1d_or_2d = np.concatenate( + [data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))], + axis=0 + ) + return data_1d_or_2d + + +def extract_logmel_spectrogram( + waveform: torch.Tensor, sample_rate: int, + output_path: Optional[Path] = None, win_length: int = 1024, + hop_length: int = 256, n_fft: int = 1024, + win_fn: callable = torch.hann_window, n_mels: int = 80, + f_min: float = 0., f_max: float = 8000, eps: float = 1e-5, + overwrite: bool = False, target_length: Optional[int] = None +): + if output_path is not None and output_path.is_file() and not overwrite: + return + + spectrogram_transform = TTSSpectrogram( + n_fft=n_fft, win_length=win_length, hop_length=hop_length, + window_fn=win_fn + ) + mel_scale_transform = TTSMelScale( + n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, + n_stft=n_fft // 2 + 1 + ) + spectrogram = spectrogram_transform(waveform) + mel_spec = mel_scale_transform(spectrogram) + logmel_spec = torch.clamp(mel_spec, min=eps).log() + assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1 + logmel_spec = logmel_spec.squeeze().t() # D x T -> T x D + if target_length is not None: + logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length) + + if output_path is not None: + np.save(output_path.as_posix(), logmel_spec) + else: + return logmel_spec + + +def extract_pitch( + waveform: torch.Tensor, sample_rate: int, + output_path: Optional[Path] = None, hop_length: int = 256, + log_scale: bool = True, phoneme_durations: Optional[List[int]] = None +): + if output_path is not None and output_path.is_file(): + return + + try: + import pyworld + except ImportError: + raise ImportError("Please install PyWORLD: pip install pyworld") + + _waveform = waveform.squeeze(0).double().numpy() + pitch, t = pyworld.dio( + _waveform, sample_rate, frame_period=hop_length / sample_rate * 1000 + ) + pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate) + + if phoneme_durations is not None: + pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations)) + try: + from scipy.interpolate import interp1d + except ImportError: + raise ImportError("Please install SciPy: pip install scipy") + nonzero_ids = np.where(pitch != 0)[0] + if len(nonzero_ids) == 0: + print((f"{output_path} has all empty values in the pitch contour")) + return + elif len(nonzero_ids) == 1: + print((f"{output_path} has only one non-zero values in the pitch contour")) + return + else: + interp_fn = interp1d( + nonzero_ids, + pitch[nonzero_ids], + fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), + bounds_error=False, + ) + pitch = interp_fn(np.arange(0, 
len(pitch))) + d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations])) + pitch = np.array( + [ + np.mean(pitch[d_cumsum[i-1]: d_cumsum[i]]) + for i in range(1, len(d_cumsum)) + ] + ) + assert len(pitch) == len(phoneme_durations) + + if log_scale: + pitch = np.log(pitch + 1) + + if output_path is not None: + np.save(output_path.as_posix(), pitch) + else: + return pitch + + +def extract_energy( + waveform: torch.Tensor, output_path: Optional[Path] = None, + hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True, + phoneme_durations: Optional[List[int]] = None +): + if output_path is not None and output_path.is_file(): + return + + assert len(waveform.shape) == 2 and waveform.shape[0] == 1 + waveform = waveform.view(1, 1, waveform.shape[1]) + waveform = F.pad( + waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0], + mode="reflect" + ) + waveform = waveform.squeeze(1) + + fourier_basis = np.fft.fft(np.eye(n_fft)) + cutoff = int((n_fft / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), + np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + forward_transform = F.conv1d( + waveform, forward_basis, stride=hop_length, padding=0 + ) + + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) + energy = torch.norm(magnitude, dim=1).squeeze(0).numpy() + + if phoneme_durations is not None: + energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations)) + d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations])) + energy = np.array( + [ + np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]]) + for i in range(1, len(d_cumsum)) + ] + ) + assert len(energy) == len(phoneme_durations) + + if log_scale: + energy = np.log(energy + 1) + + if output_path is not None: + np.save(output_path.as_posix(), energy) + else: + return energy + + +def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None): + mean_x, mean_x2, n_frames = None, None, 0 + feature_paths = feature_root.glob("*.npy") + for p in tqdm(feature_paths): + with open(p, 'rb') as f: + frames = np.load(f).squeeze() + + n_frames += frames.shape[0] + + cur_mean_x = frames.sum(axis=0) + if mean_x is None: + mean_x = cur_mean_x + else: + mean_x += cur_mean_x + + cur_mean_x2 = (frames ** 2).sum(axis=0) + if mean_x2 is None: + mean_x2 = cur_mean_x2 + else: + mean_x2 += cur_mean_x2 + + mean_x /= n_frames + mean_x2 /= n_frames + var_x = mean_x2 - mean_x ** 2 + std_x = np.sqrt(np.maximum(var_x, 1e-10)) + + if output_path is not None: + with open(output_path, 'wb') as f: + np.savez(f, mean=mean_x, std=std_x) + else: + return {"mean": mean_x, "std": std_x} + + +def ipa_phonemize(text, lang="en-us", use_g2p=False): + if use_g2p: + assert lang == "en-us", "g2pE phonemizer only works for en-us" + try: + from g2p_en import G2p + g2p = G2p() + return " ".join("|" if p == " " else p for p in g2p(text)) + except ImportError: + raise ImportError( + "Please install phonemizer: pip install g2p_en" + ) + else: + try: + from phonemizer import phonemize + from phonemizer.separator import Separator + return phonemize( + text, backend='espeak', language=lang, + separator=Separator(word="| ", phone=" ") + ) + except ImportError: + raise ImportError( + "Please install phonemizer: pip install phonemizer" + ) + + +@dataclass +class ForceAlignmentInfo(object): + tokens: List[str] + frame_durations: List[int] + start_sec: Optional[float] + end_sec: 
Optional[float] + + +def get_mfa_alignment_by_sample_id( + textgrid_zip_path: str, sample_id: str, sample_rate: int, + hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn") +) -> ForceAlignmentInfo: + try: + import tgt + except ImportError: + raise ImportError("Please install TextGridTools: pip install tgt") + + filename = f"{sample_id}.TextGrid" + out_root = Path(tempfile.gettempdir()) + tgt_path = out_root / filename + with zipfile.ZipFile(textgrid_zip_path) as f_zip: + f_zip.extract(filename, path=out_root) + textgrid = tgt.io.read_textgrid(tgt_path.as_posix()) + os.remove(tgt_path) + + phones, frame_durations = [], [] + start_sec, end_sec, end_idx = 0, 0, 0 + for t in textgrid.get_tier_by_name("phones")._objects: + s, e, p = t.start_time, t.end_time, t.text + # Trim leading silences + if len(phones) == 0: + if p in silence_phones: + continue + else: + start_sec = s + phones.append(p) + if p not in silence_phones: + end_sec = e + end_idx = len(phones) + r = sample_rate / hop_length + frame_durations.append(int(np.round(e * r) - np.round(s * r))) + # Trim tailing silences + phones = phones[:end_idx] + frame_durations = frame_durations[:end_idx] + + return ForceAlignmentInfo( + tokens=phones, frame_durations=frame_durations, start_sec=start_sec, + end_sec=end_sec + ) + + +def get_mfa_alignment( + textgrid_zip_path: str, sample_ids: List[str], sample_rate: int, + hop_length: int +) -> Dict[str, ForceAlignmentInfo]: + return { + i: get_mfa_alignment_by_sample_id( + textgrid_zip_path, i, sample_rate, hop_length + ) for i in tqdm(sample_ids) + } + + +def get_unit_alignment( + id_to_unit_tsv_path: str, sample_ids: List[str] +) -> Dict[str, ForceAlignmentInfo]: + id_to_units = { + e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path) + } + id_to_units = {i: id_to_units[i].split() for i in sample_ids} + id_to_units_collapsed = { + i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items() + } + id_to_durations = { + i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items() + } + + return { + i: ForceAlignmentInfo( + tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i], + start_sec=None, end_sec=None + ) + for i in sample_ids + } + + +def get_feature_value_min_max(feature_paths: List[str]): + v_min, v_max = 1e-8, -1e-8 + for p in tqdm(feature_paths): + _path, slice_ptr = parse_path(p) + assert len(slice_ptr) == 2 + byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) + assert is_npy_data(byte_data) + path_or_fp = io.BytesIO(byte_data) + features = np.load(path_or_fp).squeeze() + v_min = min(v_min, features.min().item()) + v_max = max(v_max, features.max().item()) + return v_min, v_max diff --git a/fairseq/examples/speech_synthesis/docs/common_voice_example.md b/fairseq/examples/speech_synthesis/docs/common_voice_example.md new file mode 100644 index 0000000000000000000000000000000000000000..1c0eef69a0adcfb60e6d89c6df63311fbb1eb4aa --- /dev/null +++ b/fairseq/examples/speech_synthesis/docs/common_voice_example.md @@ -0,0 +1,67 @@ +[[Back]](..) + +# Common Voice + +[Common Voice](https://commonvoice.mozilla.org/en/datasets) is a public domain speech corpus with 11.2K hours of read +speech in 76 languages (the latest version 7.0). We provide examples for building +[Transformer](https://arxiv.org/abs/1809.08895) models on this dataset. + + +## Data preparation +[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path `${DATA_ROOT}/${LANG_ID}`. 
+Create splits and generate audio manifests with +```bash +python -m examples.speech_synthesis.preprocessing.get_common_voice_audio_manifest \ + --data-root ${DATA_ROOT} \ + --lang ${LANG_ID} \ + --output-manifest-root ${AUDIO_MANIFEST_ROOT} --convert-to-wav +``` + +To denoise audio and trim leading/trailing silence using signal processing based VAD, run +```bash +for SPLIT in dev test train; do + python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \ + --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \ + --output-dir ${PROCESSED_DATA_ROOT} \ + --denoise --vad --vad-agg-level 2 +done +``` + +which generates a new audio TSV manifest under `${PROCESSED_DATA_ROOT}` with updated path to the processed audio and +a new column for SNR. + +To do filtering by CER, follow the [Automatic Evaluation](../docs/ljspeech_example.md#automatic-evaluation) section to +run ASR model (add `--eval-target` to `get_eval_manifest` for evaluation on the reference audio; add `--err-unit char` +to `eval_asr` to compute CER instead of WER). The example-level CER is saved to +`${EVAL_OUTPUT_ROOT}/uer_cer.${SPLIT}.tsv`. + +Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with +```bash +python -m examples.speech_synthesis.preprocessing.get_feature_manifest \ + --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \ + --output-root ${FEATURE_MANIFEST_ROOT} \ + --ipa-vocab --lang ${LANG_ID} \ + --snr-threshold 15 \ + --cer-threshold 0.1 --cer-tsv-path ${EVAL_OUTPUT_ROOT}/uer_cer.${SPLIT}.tsv +``` +where we use phoneme inputs (`--ipa-vocab`) as example. For sample filtering, we set the SNR and CER threshold +to 15 and 10%, respectively. + + +## Training +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).) + + +## Inference +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).) + +## Automatic Evaluation +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).) + +## Results + +| Language | Speakers | --arch | Params | Test MCD | Model | +|---|---|---|---|---|---| +| English | 200 | tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/cv4_en200_transformer_phn.tar) | + +[[Back]](..) diff --git a/fairseq/examples/speech_synthesis/docs/ljspeech_example.md b/fairseq/examples/speech_synthesis/docs/ljspeech_example.md new file mode 100644 index 0000000000000000000000000000000000000000..836c30d6d5fbd0af3f57f58e903a579612587f25 --- /dev/null +++ b/fairseq/examples/speech_synthesis/docs/ljspeech_example.md @@ -0,0 +1,137 @@ +[[Back]](..) + +# LJSpeech + +[LJSpeech](https://keithito.com/LJ-Speech-Dataset) is a public domain TTS +corpus with around 24 hours of English speech sampled at 22.05kHz. We provide examples for building +[Transformer](https://arxiv.org/abs/1809.08895) and [FastSpeech 2](https://arxiv.org/abs/2006.04558) +models on this dataset. 
+ + +## Data preparation + +Download data, create splits and generate audio manifests with +```bash +python -m examples.speech_synthesis.preprocessing.get_ljspeech_audio_manifest \ + --output-data-root ${AUDIO_DATA_ROOT} \ + --output-manifest-root ${AUDIO_MANIFEST_ROOT} +``` + +Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with +```bash +python -m examples.speech_synthesis.preprocessing.get_feature_manifest \ + --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \ + --output-root ${FEATURE_MANIFEST_ROOT} \ + --ipa-vocab --use-g2p +``` +where we use phoneme inputs (`--ipa-vocab --use-g2p`) as example. + +FastSpeech 2 additionally requires frame durations, pitch and energy as auxiliary training targets. +Add `--add-fastspeech-targets` to include these fields in the feature manifests. We get frame durations either from +phoneme-level force-alignment or frame-level pseudo-text unit sequence. They should be pre-computed and specified via: +- `--textgrid-zip ${TEXT_GRID_ZIP_PATH}` for a ZIP file, inside which there is one + [TextGrid](https://www.fon.hum.uva.nl/praat/manual/TextGrid.html) file per sample to provide force-alignment info. +- `--id-to-units-tsv ${ID_TO_UNIT_TSV}` for a TSV file, where there are 2 columns for sample ID and + space-delimited pseudo-text unit sequence, respectively. + +For your convenience, we provide pre-computed +[force-alignment](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_mfa.zip) from +[Montreal Forced Aligner](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) and +[pseudo-text units](s3://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_hubert.tsv) from +[HuBERT](https://github.com/pytorch/fairseq/tree/main/examples/hubert). You can also generate them by yourself using +a different software or model. + + +## Training +#### Transformer +```bash +fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \ + --config-yaml config.yaml --train-subset train --valid-subset dev \ + --num-workers 4 --max-tokens 30000 --max-update 200000 \ + --task text_to_speech --criterion tacotron2 --arch tts_transformer \ + --clip-norm 5.0 --n-frames-per-step 4 --bce-pos-weight 5.0 \ + --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \ + --encoder-normalize-before --decoder-normalize-before \ + --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss +``` +where `SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to +update it accordingly when using more than 1 GPU. 
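The frame durations consumed by FastSpeech 2 (next section) come from the `--textgrid-zip` or `--id-to-units-tsv` inputs described under Data preparation. As a minimal, illustrative sketch (the unit values below are hypothetical), a frame-level pseudo-unit sequence is collapsed into tokens and per-token frame durations by run-length encoding:

```python
# Illustrative sketch: run-length encode a frame-level unit sequence into
# (token, duration) pairs, mirroring how --id-to-units-tsv entries are used.
from itertools import groupby

units = "12 12 12 7 7 55 55 55 55".split()             # hypothetical unit sequence
tokens = [u for u, _ in groupby(units)]                # ['12', '7', '55']
durations = [len(list(g)) for _, g in groupby(units)]  # [3, 2, 4]
assert sum(durations) == len(units)
print(list(zip(tokens, durations)))                    # [('12', 3), ('7', 2), ('55', 4)]
```

With TextGrid input, durations are instead derived from each phone's start/end time converted to frame counts (`sample_rate / hop_length` frames per second).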
+ +#### FastSpeech2 +```bash +fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \ + --config-yaml config.yaml --train-subset train --valid-subset dev \ + --num-workers 4 --max-sentences 6 --max-update 200000 \ + --task text_to_speech --criterion fastspeech2 --arch fastspeech2 \ + --clip-norm 5.0 --n-frames-per-step 1 \ + --dropout 0.1 --attention-dropout 0.1 \ + --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss +``` + + +## Inference +Average the last 5 checkpoints, generate the test split spectrogram and waveform using the default Griffin-Lim vocoder: +```bash +SPLIT=test +CHECKPOINT_NAME=avg_last_5 +CHECKPOINT_PATH=${SAVE_DIR}/checkpoint_${CHECKPOINT_NAME}.pt +python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \ + --num-epoch-checkpoints 5 \ + --output ${CHECKPOINT_PATH} + +python -m examples.speech_synthesis.generate_waveform ${FEATURE_MANIFEST_ROOT} \ + --config-yaml config.yaml --gen-subset ${SPLIT} --task text_to_speech \ + --path ${CHECKPOINT_PATH} --max-tokens 50000 --spec-bwd-max-iter 32 \ + --dump-waveforms +``` +which dumps files (waveform, feature, attention plot, etc.) to `${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT}`. To +re-synthesize target waveforms for automatic evaluation, add `--dump-target`. + +## Automatic Evaluation +To start with, generate the manifest for synthetic speech, which will be taken as inputs by evaluation scripts. +```bash +python -m examples.speech_synthesis.evaluation.get_eval_manifest \ + --generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \ + --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \ + --output-path ${EVAL_OUTPUT_ROOT}/eval.tsv \ + --vocoder griffin_lim --sample-rate 22050 --audio-format flac \ + --use-resynthesized-target +``` +Speech recognition (ASR) models usually operate at lower sample rates (e.g. 16kHz). For the WER/CER metric, +you may need to resample the audios accordingly --- add `--output-sample-rate 16000` for `generate_waveform.py` and +use `--sample-rate 16000` for `get_eval_manifest.py`. + + +#### WER/CER metric +We use wav2vec 2.0 ASR model as example. [Download](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec) +the model checkpoint and dictionary, then compute WER/CER with +```bash +python -m examples.speech_synthesis.evaluation.eval_asr \ + --audio-header syn --text-header text --err-unit char --split ${SPLIT} \ + --w2v-ckpt ${WAV2VEC2_CHECKPOINT_PATH} --w2v-dict-dir ${WAV2VEC2_DICT_DIR} \ + --raw-manifest ${EVAL_OUTPUT_ROOT}/eval_16khz.tsv --asr-dir ${EVAL_OUTPUT_ROOT}/asr +``` + +#### MCD/MSD metric +```bash +python -m examples.speech_synthesis.evaluation.eval_sp \ + ${EVAL_OUTPUT_ROOT}/eval.tsv --mcd --msd +``` + +#### F0 metrics +```bash +python -m examples.speech_synthesis.evaluation.eval_f0 \ + ${EVAL_OUTPUT_ROOT}/eval.tsv --gpe --vde --ffe +``` + + +## Results + +| --arch | Params | Test MCD | Model | +|---|---|---|---| +| tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_transformer_phn.tar) | +| fastspeech2 | 41M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_fastspeech2_phn.tar) | + +[[Back]](..) 
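For reference, the evaluation manifest written by `get_eval_manifest` above is a tab-separated file with `id`, `syn`, `ref`, `text` and `speaker` columns. The sketch below (the file path is hypothetical) shows how the evaluation scripts consume it:

```python
# Minimal sketch: read the evaluation manifest and check that the synthesized
# and reference audio files it points to actually exist.
import csv
from pathlib import Path

with open("eval.tsv") as f:  # e.g. ${EVAL_OUTPUT_ROOT}/eval.tsv
    for row in csv.DictReader(f, delimiter="\t"):
        for key in ("syn", "ref"):
            if not Path(row[key]).is_file():
                print(f"missing {key} audio for id={row['id']}: {row[key]}")
```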
diff --git a/fairseq/examples/speech_synthesis/docs/vctk_example.md b/fairseq/examples/speech_synthesis/docs/vctk_example.md new file mode 100644 index 0000000000000000000000000000000000000000..6808256d44ef08b0350aa4d74fa062673990137a --- /dev/null +++ b/fairseq/examples/speech_synthesis/docs/vctk_example.md @@ -0,0 +1,61 @@ +[[Back]](..) + +# VCTK + +[VCTK](https://datashare.ed.ac.uk/handle/10283/3443) is an open English speech corpus. We provide examples +for building [Transformer](https://arxiv.org/abs/1809.08895) models on this dataset. + + +## Data preparation +Download data, create splits and generate audio manifests with +```bash +python -m examples.speech_synthesis.preprocessing.get_vctk_audio_manifest \ + --output-data-root ${AUDIO_DATA_ROOT} \ + --output-manifest-root ${AUDIO_MANIFEST_ROOT} +``` + +To denoise audio and trim leading/trailing silence using signal processing based VAD, run +```bash +for SPLIT in dev test train; do + python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \ + --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \ + --output-dir ${PROCESSED_DATA_ROOT} \ + --denoise --vad --vad-agg-level 3 +done +``` +which generates a new audio TSV manifest under `${PROCESSED_DATA_ROOT}` with updated path to the processed audio and +a new column for SNR. + +To do filtering by CER, follow the [Automatic Evaluation](../docs/ljspeech_example.md#automatic-evaluation) section to +run ASR model (add `--eval-target` to `get_eval_manifest` for evaluation on the reference audio; add `--err-unit char` +to `eval_asr` to compute CER instead of WER). The example-level CER is saved to +`${EVAL_OUTPUT_ROOT}/uer_cer.${SPLIT}.tsv`. + +Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with +```bash +python -m examples.speech_synthesis.preprocessing.get_feature_manifest \ + --audio-manifest-root ${PROCESSED_DATA_ROOT} \ + --output-root ${FEATURE_MANIFEST_ROOT} \ + --ipa-vocab --use-g2p \ + --snr-threshold 15 \ + --cer-threshold 0.1 --cer-tsv-path ${EVAL_OUTPUT_ROOT}/uer_cer.${SPLIT}.tsv +``` +where we use phoneme inputs (`--ipa-vocab --use-g2p`) as example. For sample filtering, we set the SNR and CER threshold +to 15 and 10%, respectively. + +## Training +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).) + +## Inference +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).) + +## Automatic Evaluation +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).) + +## Results + +| --arch | Params | Test MCD | Model | +|---|---|---|---| +| tts_transformer | 54M | 3.4 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/vctk_transformer_phn.tar) | + +[[Back]](..) diff --git a/fairseq/examples/speech_synthesis/evaluation/__init__.py b/fairseq/examples/speech_synthesis/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a --- /dev/null +++ b/fairseq/examples/speech_synthesis/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
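The CER filtering used in the Common Voice and VCTK recipes relies on the per-example error-rate file written by `eval_asr` (shown next), which has `id`, `audio` and `uer` columns. A minimal sketch of inspecting it against the 10% threshold (the file path is hypothetical):

```python
# Minimal sketch: list samples whose character error rate exceeds the 0.1
# threshold passed to get_feature_manifest via --cer-threshold.
import csv

with open("uer_cer.train.tsv") as f:  # e.g. ${EVAL_OUTPUT_ROOT}/uer_cer.train.tsv
    rows = list(csv.DictReader(f, delimiter="\t"))
bad = [r for r in rows if float(r["uer"]) > 0.1]
print(f"{len(bad)} of {len(rows)} samples exceed 10% CER")
```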
diff --git a/fairseq/examples/speech_synthesis/evaluation/eval_asr.py b/fairseq/examples/speech_synthesis/evaluation/eval_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..005a11bfb34ca477ad9e133acd60f249e66cda47 --- /dev/null +++ b/fairseq/examples/speech_synthesis/evaluation/eval_asr.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import editdistance +import re +import shutil +import soundfile as sf +import subprocess +from pathlib import Path + +from examples.speech_to_text.data_utils import load_tsv_to_dicts + + +def preprocess_text(text): + text = "|".join(re.sub(r"[^A-Z' ]", " ", text.upper()).split()) + text = " ".join(text) + return text + + +def prepare_w2v_data( + dict_dir, sample_rate, label, audio_paths, texts, split, data_dir +): + data_dir.mkdir(parents=True, exist_ok=True) + shutil.copyfile( + dict_dir / f"dict.{label}.txt", + data_dir / f"dict.{label}.txt" + ) + with open(data_dir / f"{split}.tsv", "w") as f: + f.write("/\n") + for audio_path in audio_paths: + wav, sr = sf.read(audio_path) + assert sr == sample_rate, f"{sr} != sample_rate" + nsample = len(wav) + f.write(f"{audio_path}\t{nsample}\n") + with open(data_dir / f"{split}.{label}", "w") as f: + for text in texts: + text = preprocess_text(text) + f.write(f"{text}\n") + + +def run_asr(asr_dir, split, w2v_ckpt, w2v_label, res_dir): + """ + results will be saved at + {res_dir}/{ref,hypo}.word-{w2v_ckpt.filename}-{split}.txt + """ + cmd = ["python", "-m", "examples.speech_recognition.infer"] + cmd += [str(asr_dir.resolve())] + cmd += ["--task", "audio_finetuning", "--nbest", "1", "--quiet"] + cmd += ["--w2l-decoder", "viterbi", "--criterion", "ctc"] + cmd += ["--post-process", "letter", "--max-tokens", "4000000"] + cmd += ["--path", str(w2v_ckpt.resolve()), "--labels", w2v_label] + cmd += ["--gen-subset", split, "--results-path", str(res_dir.resolve())] + + print(f"running cmd:\n{' '.join(cmd)}") + subprocess.run(cmd, check=True) + + +def compute_error_rate(hyp_wrd_path, ref_wrd_path, unit="word"): + """each line is " (None-)" """ + tokenize_line = { + "word": lambda x: re.sub(r" \(.*\)$", "", x.rstrip()).split(), + "char": lambda x: list(re.sub(r" \(.*\)$", "", x.rstrip())) + }.get(unit) + if tokenize_line is None: + raise ValueError(f"{unit} not supported") + + inds = [int(re.sub(r"\D*(\d*)\D*", r"\1", line)) + for line in open(hyp_wrd_path)] + hyps = [tokenize_line(line) for line in open(hyp_wrd_path)] + refs = [tokenize_line(line) for line in open(ref_wrd_path)] + assert(len(hyps) == len(refs)) + err_rates = [ + editdistance.eval(hyp, ref) / len(ref) for hyp, ref in zip(hyps, refs) + ] + ind_to_err_rates = {i: e for i, e in zip(inds, err_rates)} + return ind_to_err_rates + + +def main(args): + samples = load_tsv_to_dicts(args.raw_manifest) + ids = [ + sample[args.id_header] if args.id_header else "" for sample in samples + ] + audio_paths = [sample[args.audio_header] for sample in samples] + texts = [sample[args.text_header] for sample in samples] + + prepare_w2v_data( + args.w2v_dict_dir, + args.w2v_sample_rate, + args.w2v_label, + audio_paths, + texts, + args.split, + args.asr_dir + ) + run_asr(args.asr_dir, args.split, args.w2v_ckpt, args.w2v_label, args.asr_dir) + ind_to_err_rates = compute_error_rate( + args.asr_dir / f"hypo.word-{args.w2v_ckpt.name}-{args.split}.txt", + args.asr_dir / 
f"ref.word-{args.w2v_ckpt.name}-{args.split}.txt", + args.err_unit, + ) + + uer_path = args.asr_dir / f"uer_{args.err_unit}.{args.split}.tsv" + with open(uer_path, "w") as f: + f.write("id\taudio\tuer\n") + for ind, (id_, audio_path) in enumerate(zip(ids, audio_paths)): + f.write(f"{id_}\t{audio_path}\t{ind_to_err_rates[ind]:.4f}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--raw-manifest", required=True, type=Path) + parser.add_argument("--asr-dir", required=True, type=Path) + parser.add_argument("--id-header", default="id", type=str) + parser.add_argument("--audio-header", default="audio", type=str) + parser.add_argument("--text-header", default="src_text", type=str) + parser.add_argument("--split", default="raw", type=str) + parser.add_argument("--w2v-ckpt", required=True, type=Path) + parser.add_argument("--w2v-dict-dir", required=True, type=Path) + parser.add_argument("--w2v-sample-rate", default=16000, type=int) + parser.add_argument("--w2v-label", default="ltr", type=str) + parser.add_argument("--err-unit", default="word", type=str) + args = parser.parse_args() + + main(args) diff --git a/fairseq/examples/speech_synthesis/evaluation/eval_f0.py b/fairseq/examples/speech_synthesis/evaluation/eval_f0.py new file mode 100644 index 0000000000000000000000000000000000000000..df721d683113b44957149cfc3cddaba36520a22c --- /dev/null +++ b/fairseq/examples/speech_synthesis/evaluation/eval_f0.py @@ -0,0 +1,266 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Signal processing-based evaluation using waveforms +""" +import numpy as np +import os.path as op + +import torchaudio +import tqdm +from tabulate import tabulate + +from examples.speech_synthesis.utils import ( + gross_pitch_error, voicing_decision_error, f0_frame_error +) +from examples.speech_synthesis.evaluation.eval_sp import load_eval_spec + + +def difference_function(x, n, tau_max): + """ + Compute difference function of data x. This solution is implemented directly + with Numpy fft. + + + :param x: audio data + :param n: length of data + :param tau_max: integration window size + :return: difference function + :rtype: list + """ + + x = np.array(x, np.float64) + w = x.size + tau_max = min(tau_max, w) + x_cumsum = np.concatenate((np.array([0.]), (x * x).cumsum())) + size = w + tau_max + p2 = (size // 32).bit_length() + nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32) + size_pad = min(x * 2 ** p2 for x in nice_numbers if x * 2 ** p2 >= size) + fc = np.fft.rfft(x, size_pad) + conv = np.fft.irfft(fc * fc.conjugate())[:tau_max] + return x_cumsum[w:w - tau_max:-1] + x_cumsum[w] - x_cumsum[:tau_max] - \ + 2 * conv + + +def cumulative_mean_normalized_difference_function(df, n): + """ + Compute cumulative mean normalized difference function (CMND). + + :param df: Difference function + :param n: length of data + :return: cumulative mean normalized difference function + :rtype: list + """ + + # scipy method + cmn_df = df[1:] * range(1, n) / np.cumsum(df[1:]).astype(float) + return np.insert(cmn_df, 0, 1) + + +def get_pitch(cmdf, tau_min, tau_max, harmo_th=0.1): + """ + Return fundamental period of a frame based on CMND function. 
+ + :param cmdf: Cumulative Mean Normalized Difference function + :param tau_min: minimum period for speech + :param tau_max: maximum period for speech + :param harmo_th: harmonicity threshold to determine if it is necessary to + compute pitch frequency + :return: fundamental period if there is values under threshold, 0 otherwise + :rtype: float + """ + tau = tau_min + while tau < tau_max: + if cmdf[tau] < harmo_th: + while tau + 1 < tau_max and cmdf[tau + 1] < cmdf[tau]: + tau += 1 + return tau + tau += 1 + + return 0 # if unvoiced + + +def compute_yin(sig, sr, w_len=512, w_step=256, f0_min=100, f0_max=500, + harmo_thresh=0.1): + """ + + Compute the Yin Algorithm. Return fundamental frequency and harmonic rate. + + https://github.com/NVIDIA/mellotron adaption of + https://github.com/patriceguyot/Yin + + :param sig: Audio signal (list of float) + :param sr: sampling rate (int) + :param w_len: size of the analysis window (samples) + :param w_step: size of the lag between two consecutives windows (samples) + :param f0_min: Minimum fundamental frequency that can be detected (hertz) + :param f0_max: Maximum fundamental frequency that can be detected (hertz) + :param harmo_thresh: Threshold of detection. The yalgorithmù return the + first minimum of the CMND function below this threshold. + + :returns: + + * pitches: list of fundamental frequencies, + * harmonic_rates: list of harmonic rate values for each fundamental + frequency value (= confidence value) + * argmins: minimums of the Cumulative Mean Normalized DifferenceFunction + * times: list of time of each estimation + :rtype: tuple + """ + + tau_min = int(sr / f0_max) + tau_max = int(sr / f0_min) + + # time values for each analysis window + time_scale = range(0, len(sig) - w_len, w_step) + times = [t/float(sr) for t in time_scale] + frames = [sig[t:t + w_len] for t in time_scale] + + pitches = [0.0] * len(time_scale) + harmonic_rates = [0.0] * len(time_scale) + argmins = [0.0] * len(time_scale) + + for i, frame in enumerate(frames): + # Compute YIN + df = difference_function(frame, w_len, tau_max) + cm_df = cumulative_mean_normalized_difference_function(df, tau_max) + p = get_pitch(cm_df, tau_min, tau_max, harmo_thresh) + + # Get results + if np.argmin(cm_df) > tau_min: + argmins[i] = float(sr / np.argmin(cm_df)) + if p != 0: # A pitch was found + pitches[i] = float(sr / p) + harmonic_rates[i] = cm_df[p] + else: # No pitch, but we compute a value of the harmonic rate + harmonic_rates[i] = min(cm_df) + + return pitches, harmonic_rates, argmins, times + + +def extract_f0(samples): + f0_samples = [] + for sample in tqdm.tqdm(samples): + if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]): + f0_samples.append(None) + continue + + # assume single channel + yref, sr = torchaudio.load(sample["ref"]) + ysyn, _sr = torchaudio.load(sample["syn"]) + yref, ysyn = yref[0], ysyn[0] + assert sr == _sr, f"{sr} != {_sr}" + + yref_f0 = compute_yin(yref, sr) + ysyn_f0 = compute_yin(ysyn, sr) + + f0_samples += [ + { + "ref": yref_f0, + "syn": ysyn_f0 + } + ] + + return f0_samples + + +def eval_f0_error(samples, distortion_fn): + results = [] + for sample in tqdm.tqdm(samples): + if sample is None: + results.append(None) + continue + # assume single channel + yref_f, _, _, yref_t = sample["ref"] + ysyn_f, _, _, ysyn_t = sample["syn"] + + yref_f = np.array(yref_f) + yref_t = np.array(yref_t) + ysyn_f = np.array(ysyn_f) + ysyn_t = np.array(ysyn_t) + + distortion = distortion_fn(yref_t, yref_f, ysyn_t, ysyn_f) + results.append((distortion.item(), + 
len(yref_f), + len(ysyn_f) + )) + return results + + +def eval_gross_pitch_error(samples): + return eval_f0_error(samples, gross_pitch_error) + + +def eval_voicing_decision_error(samples): + return eval_f0_error(samples, voicing_decision_error) + + +def eval_f0_frame_error(samples): + return eval_f0_error(samples, f0_frame_error) + + +def print_results(results, show_bin): + results = np.array(list(filter(lambda x: x is not None, results))) + + np.set_printoptions(precision=3) + + def _print_result(results): + res = { + "nutt": len(results), + "error": results[:, 0].mean(), + "std": results[:, 0].std(), + "dur_ref": int(results[:, 1].sum()), + "dur_syn": int(results[:, 2].sum()), + } + print(tabulate([res.values()], res.keys(), floatfmt=".4f")) + + print(">>>> ALL") + _print_result(results) + + if show_bin: + edges = [0, 200, 400, 600, 800, 1000, 2000, 4000] + for i in range(1, len(edges)): + mask = np.logical_and(results[:, 1] >= edges[i-1], + results[:, 1] < edges[i]) + if not mask.any(): + continue + bin_results = results[mask] + print(f">>>> ({edges[i-1]}, {edges[i]})") + _print_result(bin_results) + + +def main(eval_f0, gpe, vde, ffe, show_bin): + samples = load_eval_spec(eval_f0) + if gpe or vde or ffe: + f0_samples = extract_f0(samples) + + if gpe: + print("===== Evaluate Gross Pitch Error =====") + results = eval_gross_pitch_error(f0_samples) + print_results(results, show_bin) + if vde: + print("===== Evaluate Voicing Decision Error =====") + results = eval_voicing_decision_error(f0_samples) + print_results(results, show_bin) + if ffe: + print("===== Evaluate F0 Frame Error =====") + results = eval_f0_frame_error(f0_samples) + print_results(results, show_bin) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("eval_f0") + parser.add_argument("--gpe", action="store_true") + parser.add_argument("--vde", action="store_true") + parser.add_argument("--ffe", action="store_true") + parser.add_argument("--show-bin", action="store_true") + args = parser.parse_args() + + main(args.eval_f0, args.gpe, args.vde, args.ffe, args.show_bin) diff --git a/fairseq/examples/speech_synthesis/evaluation/eval_sp.py b/fairseq/examples/speech_synthesis/evaluation/eval_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..702c4980389624f788abc0b42cdf54757a52512f --- /dev/null +++ b/fairseq/examples/speech_synthesis/evaluation/eval_sp.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +""" +Signal processing-based evaluation using waveforms +""" + +import csv +import numpy as np +import os.path as op + +import torch +import tqdm +from tabulate import tabulate +import torchaudio + +from examples.speech_synthesis.utils import batch_mel_spectral_distortion +from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion + + +def load_eval_spec(path): + with open(path) as f: + reader = csv.DictReader(f, delimiter='\t') + samples = list(reader) + return samples + + +def eval_distortion(samples, distortion_fn, device="cuda"): + nmiss = 0 + results = [] + for sample in tqdm.tqdm(samples): + if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]): + nmiss += 1 + results.append(None) + continue + # assume single channel + yref, sr = torchaudio.load(sample["ref"]) + ysyn, _sr = torchaudio.load(sample["syn"]) + yref, ysyn = yref[0].to(device), ysyn[0].to(device) + assert sr == _sr, f"{sr} != {_sr}" + + distortion, extra = distortion_fn([yref], [ysyn], sr, None)[0] + _, _, _, _, _, pathmap = extra + nins = torch.sum(pathmap.sum(dim=1) - 1) # extra frames in syn + ndel = torch.sum(pathmap.sum(dim=0) - 1) # missing frames from syn + results.append( + (distortion.item(), # path distortion + pathmap.size(0), # yref num frames + pathmap.size(1), # ysyn num frames + pathmap.sum().item(), # path length + nins.item(), # insertion + ndel.item(), # deletion + ) + ) + return results + + +def eval_mel_cepstral_distortion(samples, device="cuda"): + return eval_distortion(samples, batch_mel_cepstral_distortion, device) + + +def eval_mel_spectral_distortion(samples, device="cuda"): + return eval_distortion(samples, batch_mel_spectral_distortion, device) + + +def print_results(results, show_bin): + results = np.array(list(filter(lambda x: x is not None, results))) + + np.set_printoptions(precision=3) + + def _print_result(results): + dist, dur_ref, dur_syn, dur_ali, nins, ndel = results.sum(axis=0) + res = { + "nutt": len(results), + "dist": dist, + "dur_ref": int(dur_ref), + "dur_syn": int(dur_syn), + "dur_ali": int(dur_ali), + "dist_per_ref_frm": dist/dur_ref, + "dist_per_syn_frm": dist/dur_syn, + "dist_per_ali_frm": dist/dur_ali, + "ins": nins/dur_ref, + "del": ndel/dur_ref, + } + print(tabulate( + [res.values()], + res.keys(), + floatfmt=".4f" + )) + + print(">>>> ALL") + _print_result(results) + + if show_bin: + edges = [0, 200, 400, 600, 800, 1000, 2000, 4000] + for i in range(1, len(edges)): + mask = np.logical_and(results[:, 1] >= edges[i-1], + results[:, 1] < edges[i]) + if not mask.any(): + continue + bin_results = results[mask] + print(f">>>> ({edges[i-1]}, {edges[i]})") + _print_result(bin_results) + + +def main(eval_spec, mcd, msd, show_bin): + samples = load_eval_spec(eval_spec) + device = "cpu" + if mcd: + print("===== Evaluate Mean Cepstral Distortion =====") + results = eval_mel_cepstral_distortion(samples, device) + print_results(results, show_bin) + if msd: + print("===== Evaluate Mean Spectral Distortion =====") + results = eval_mel_spectral_distortion(samples, device) + print_results(results, show_bin) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("eval_spec") + parser.add_argument("--mcd", action="store_true") + parser.add_argument("--msd", action="store_true") + parser.add_argument("--show-bin", action="store_true") + args = parser.parse_args() + + main(args.eval_spec, args.mcd, args.msd, args.show_bin) diff --git a/fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py 
b/fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..44b3685bb221a5bd4522668465902fb1d2eb40ec --- /dev/null +++ b/fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import csv +from pathlib import Path + + +def main(args): + """ + `uid syn ref text` + """ + in_root = Path(args.generation_root).resolve() + ext = args.audio_format + with open(args.audio_manifest) as f, open(args.output_path, "w") as f_out: + reader = csv.DictReader( + f, delimiter="\t", quotechar=None, doublequote=False, + lineterminator="\n", quoting=csv.QUOTE_NONE + ) + header = ["id", "syn", "ref", "text", "speaker"] + f_out.write("\t".join(header) + "\n") + for row in reader: + dir_name = f"{ext}_{args.sample_rate}hz_{args.vocoder}" + id_ = row["id"] + syn = (in_root / dir_name / f"{id_}.{ext}").as_posix() + ref = row["audio"] + if args.use_resynthesized_target: + ref = (in_root / f"{dir_name}_tgt" / f"{id_}.{ext}").as_posix() + if args.eval_target: + syn = row["audio"] + sample = [id_, syn, ref, row["tgt_text"], row["speaker"]] + f_out.write("\t".join(sample) + "\n") + print(f"wrote evaluation file to {args.output_path}") + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--generation-root", help="output directory for generate_waveform.py" + ) + parser.add_argument( + "--audio-manifest", + help="used to determine the original utterance ID and text" + ) + parser.add_argument( + "--output-path", help="path to output evaluation spec file" + ) + parser.add_argument( + "--use-resynthesized-target", action="store_true", + help="use resynthesized reference instead of the original audio" + ) + parser.add_argument( + "--eval-target", action="store_true", + help="evaluate reference instead of model prediction" + ) + parser.add_argument("--vocoder", type=str, default="griffin_lim") + parser.add_argument("--sample-rate", type=int, default=22_050) + parser.add_argument("--audio-format", type=str, default="wav") + args = parser.parse_args() + + main(args) diff --git a/fairseq/examples/speech_synthesis/generate_waveform.py b/fairseq/examples/speech_synthesis/generate_waveform.py new file mode 100644 index 0000000000000000000000000000000000000000..3b56190dbe7bbce72992e3a547415899df2f18db --- /dev/null +++ b/fairseq/examples/speech_synthesis/generate_waveform.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
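+
+# Inference-time dumper: loads a TTS checkpoint, generates for --gen-subset and,
+# depending on the --dump-* flags, writes features, waveforms, attention maps,
+# EOS probabilities and plots under --results-path.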
+ +import ast +import logging +import matplotlib.pyplot as plt +import numpy as np +from pathlib import Path +import soundfile as sf +import sys +import torch +import torchaudio + +from fairseq import checkpoint_utils, options, tasks, utils +from fairseq.logging import progress_bar +from fairseq.tasks.text_to_speech import plot_tts_output +from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDataset + + +logging.basicConfig() +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def make_parser(): + parser = options.get_speech_generation_parser() + parser.add_argument("--dump-features", action="store_true") + parser.add_argument("--dump-waveforms", action="store_true") + parser.add_argument("--dump-attentions", action="store_true") + parser.add_argument("--dump-eos-probs", action="store_true") + parser.add_argument("--dump-plots", action="store_true") + parser.add_argument("--dump-target", action="store_true") + parser.add_argument("--output-sample-rate", default=22050, type=int) + parser.add_argument("--teacher-forcing", action="store_true") + parser.add_argument( + "--audio-format", type=str, default="wav", choices=["wav", "flac"] + ) + return parser + + +def postprocess_results( + dataset: TextToSpeechDataset, sample, hypos, resample_fn, dump_target +): + def to_np(x): + return None if x is None else x.detach().cpu().numpy() + + sample_ids = [dataset.ids[i] for i in sample["id"].tolist()] + texts = sample["src_texts"] if "src_texts" in sample else [""] * len(hypos) + attns = [to_np(hypo["attn"]) for hypo in hypos] + eos_probs = [to_np(hypo.get("eos_prob", None)) for hypo in hypos] + feat_preds = [to_np(hypo["feature"]) for hypo in hypos] + wave_preds = [to_np(resample_fn(h["waveform"])) for h in hypos] + if dump_target: + feat_targs = [to_np(hypo["targ_feature"]) for hypo in hypos] + wave_targs = [to_np(resample_fn(h["targ_waveform"])) for h in hypos] + else: + feat_targs = [None for _ in hypos] + wave_targs = [None for _ in hypos] + + return zip(sample_ids, texts, attns, eos_probs, feat_preds, wave_preds, + feat_targs, wave_targs) + + +def dump_result( + is_na_model, + args, + vocoder, + sample_id, + text, + attn, + eos_prob, + feat_pred, + wave_pred, + feat_targ, + wave_targ, +): + sample_rate = args.output_sample_rate + out_root = Path(args.results_path) + if args.dump_features: + feat_dir = out_root / "feat" + feat_dir.mkdir(exist_ok=True, parents=True) + np.save(feat_dir / f"{sample_id}.npy", feat_pred) + if args.dump_target: + feat_tgt_dir = out_root / "feat_tgt" + feat_tgt_dir.mkdir(exist_ok=True, parents=True) + np.save(feat_tgt_dir / f"{sample_id}.npy", feat_targ) + if args.dump_attentions: + attn_dir = out_root / "attn" + attn_dir.mkdir(exist_ok=True, parents=True) + np.save(attn_dir / f"{sample_id}.npy", attn.numpy()) + if args.dump_eos_probs and not is_na_model: + eos_dir = out_root / "eos" + eos_dir.mkdir(exist_ok=True, parents=True) + np.save(eos_dir / f"{sample_id}.npy", eos_prob) + + if args.dump_plots: + images = [feat_pred.T] if is_na_model else [feat_pred.T, attn] + names = ["output"] if is_na_model else ["output", "alignment"] + if feat_targ is not None: + images = [feat_targ.T] + images + names = [f"target (idx={sample_id})"] + names + if is_na_model: + plot_tts_output(images, names, attn, "alignment", suptitle=text) + else: + plot_tts_output(images, names, eos_prob, "eos prob", suptitle=text) + plot_dir = out_root / "plot" + plot_dir.mkdir(exist_ok=True, parents=True) + 
plt.savefig(plot_dir / f"{sample_id}.png") + plt.close() + + if args.dump_waveforms: + ext = args.audio_format + if wave_pred is not None: + wav_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}" + wav_dir.mkdir(exist_ok=True, parents=True) + sf.write(wav_dir / f"{sample_id}.{ext}", wave_pred, sample_rate) + if args.dump_target and wave_targ is not None: + wav_tgt_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}_tgt" + wav_tgt_dir.mkdir(exist_ok=True, parents=True) + sf.write(wav_tgt_dir / f"{sample_id}.{ext}", wave_targ, sample_rate) + + +def main(args): + assert(args.dump_features or args.dump_waveforms or args.dump_attentions + or args.dump_eos_probs or args.dump_plots) + if args.max_tokens is None and args.batch_size is None: + args.max_tokens = 8000 + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + task = tasks.setup_task(args) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [args.path], + task=task, + arg_overrides=ast.literal_eval(args.model_overrides), + ) + model = models[0].cuda() if use_cuda else models[0] + # use the original n_frames_per_step + task.args.n_frames_per_step = saved_cfg.task.n_frames_per_step + task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task) + + data_cfg = task.data_cfg + sample_rate = data_cfg.config.get("features", {}).get("sample_rate", 22050) + resample_fn = { + False: lambda x: x, + True: lambda x: torchaudio.sox_effects.apply_effects_tensor( + x.detach().cpu().unsqueeze(0), sample_rate, + [['rate', str(args.output_sample_rate)]] + )[0].squeeze(0) + }.get(args.output_sample_rate != sample_rate) + if args.output_sample_rate != sample_rate: + logger.info(f"resampling to {args.output_sample_rate}Hz") + + generator = task.build_generator([model], args) + itr = task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_sentences=args.batch_size, + max_positions=(sys.maxsize, sys.maxsize), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + num_shards=args.num_shards, + shard_id=args.shard_id, + num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, + ).next_epoch_itr(shuffle=False) + + Path(args.results_path).mkdir(exist_ok=True, parents=True) + is_na_model = getattr(model, "NON_AUTOREGRESSIVE", False) + dataset = task.dataset(args.gen_subset) + vocoder = task.args.vocoder + with progress_bar.build_progress_bar(args, itr) as t: + for sample in t: + sample = utils.move_to_cuda(sample) if use_cuda else sample + hypos = generator.generate(model, sample, has_targ=args.dump_target) + for result in postprocess_results( + dataset, sample, hypos, resample_fn, args.dump_target + ): + dump_result(is_na_model, args, vocoder, *result) + + +def cli_main(): + parser = make_parser() + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/speech_synthesis/preprocessing/__init__.py b/fairseq/examples/speech_synthesis/preprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
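`generate_waveform.py` is normally launched through the fairseq CLI machinery. A sketch of an equivalent in-process call, assuming a trained text-to-speech checkpoint and a prepared feature manifest root (all paths and the `--max-tokens` budget are placeholders):

```
# Hypothetical invocation; the data directory is assumed to contain the usual
# config.yaml produced by the feature preprocessing scripts.
from fairseq import options
from examples.speech_synthesis.generate_waveform import make_parser, main

args = options.parse_args_and_arch(make_parser(), input_args=[
    "path/to/feature_manifest_root",
    "--task", "text_to_speech",
    "--gen-subset", "test",
    "--path", "checkpoints/checkpoint_best.pt",
    "--max-tokens", "8000",
    "--dump-waveforms", "--dump-plots",
    "--results-path", "outputs/generate-waveform",
])
main(args)
```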
diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py b/fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..4e13b38a5d3fb44dd3969e6afcb8f202274ee3b7 --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py @@ -0,0 +1,204 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +import csv +import tempfile +from collections import defaultdict +from pathlib import Path + +import torchaudio +try: + import webrtcvad +except ImportError: + raise ImportError("Please install py-webrtcvad: pip install webrtcvad") +import pandas as pd +from tqdm import tqdm + +from examples.speech_synthesis.preprocessing.denoiser.pretrained import master64 +import examples.speech_synthesis.preprocessing.denoiser.utils as utils +from examples.speech_synthesis.preprocessing.vad import ( + frame_generator, vad_collector, read_wave, write_wave, FS_MS, THRESHOLD, + SCALE +) +from examples.speech_to_text.data_utils import save_df_to_tsv + + +log = logging.getLogger(__name__) + +PATHS = ["after_denoise", "after_vad"] +MIN_T = 0.05 + + +def generate_tmp_filename(extension="txt"): + return tempfile._get_default_tempdir() + "/" + \ + next(tempfile._get_candidate_names()) + "." + extension + + +def convert_sr(inpath, sr, output_path=None): + if not output_path: + output_path = generate_tmp_filename("wav") + cmd = f"sox {inpath} -r {sr} {output_path}" + os.system(cmd) + return output_path + + +def apply_vad(vad, inpath): + audio, sample_rate = read_wave(inpath) + frames = frame_generator(FS_MS, audio, sample_rate) + frames = list(frames) + segments = vad_collector(sample_rate, FS_MS, 300, vad, frames) + merge_segments = list() + timestamp_start = 0.0 + timestamp_end = 0.0 + # removing start, end, and long sequences of sils + for i, segment in enumerate(segments): + merge_segments.append(segment[0]) + if i and timestamp_start: + sil_duration = segment[1] - timestamp_end + if sil_duration > THRESHOLD: + merge_segments.append(int(THRESHOLD / SCALE) * (b'\x00')) + else: + merge_segments.append(int((sil_duration / SCALE)) * (b'\x00')) + timestamp_start = segment[1] + timestamp_end = segment[2] + segment = b''.join(merge_segments) + return segment, sample_rate + + +def write(wav, filename, sr=16_000): + # Normalize audio if it prevents clipping + wav = wav / max(wav.abs().max().item(), 1) + torchaudio.save(filename, wav.cpu(), sr, encoding="PCM_S", + bits_per_sample=16) + + +def process(args): + # making sure we are requested either denoise or vad + if not args.denoise and not args.vad: + log.error("No denoise or vad is requested.") + return + + log.info("Creating out directories...") + if args.denoise: + out_denoise = Path(args.output_dir).absolute().joinpath(PATHS[0]) + out_denoise.mkdir(parents=True, exist_ok=True) + if args.vad: + out_vad = Path(args.output_dir).absolute().joinpath(PATHS[1]) + out_vad.mkdir(parents=True, exist_ok=True) + + log.info("Loading pre-trained speech enhancement model...") + model = master64().to(args.device) + + log.info("Building the VAD model...") + vad = webrtcvad.Vad(int(args.vad_agg_level)) + + # preparing the output dict + output_dict = defaultdict(list) + + log.info(f"Parsing input manifest: {args.audio_manifest}") + with open(args.audio_manifest, "r") as f: + 
manifest_dict = csv.DictReader(f, delimiter="\t") + for row in tqdm(manifest_dict): + filename = str(row["audio"]) + + final_output = filename + keep_sample = True + n_frames = row["n_frames"] + snr = -1 + if args.denoise: + output_path_denoise = out_denoise.joinpath(Path(filename).name) + # convert to 16khz in case we use a differet sr + tmp_path = convert_sr(final_output, 16000) + + # loading audio file and generating the enhanced version + out, sr = torchaudio.load(tmp_path) + out = out.to(args.device) + estimate = model(out) + estimate = (1 - args.dry_wet) * estimate + args.dry_wet * out + write(estimate[0], str(output_path_denoise), sr) + + snr = utils.cal_snr(out, estimate) + snr = snr.cpu().detach().numpy()[0][0] + final_output = str(output_path_denoise) + + if args.vad: + output_path_vad = out_vad.joinpath(Path(filename).name) + sr = torchaudio.info(final_output).sample_rate + if sr in [16000, 32000, 48000]: + tmp_path = final_output + elif sr < 16000: + tmp_path = convert_sr(final_output, 16000) + elif sr < 32000: + tmp_path = convert_sr(final_output, 32000) + else: + tmp_path = convert_sr(final_output, 48000) + # apply VAD + segment, sample_rate = apply_vad(vad, tmp_path) + if len(segment) < sample_rate * MIN_T: + keep_sample = False + print(( + f"WARNING: skip {filename} because it is too short " + f"after VAD ({len(segment) / sample_rate} < {MIN_T})" + )) + else: + if sample_rate != sr: + tmp_path = generate_tmp_filename("wav") + write_wave(tmp_path, segment, sample_rate) + convert_sr(tmp_path, sr, + output_path=str(output_path_vad)) + else: + write_wave(str(output_path_vad), segment, sample_rate) + final_output = str(output_path_vad) + segment, _ = torchaudio.load(final_output) + n_frames = segment.size(1) + + if keep_sample: + output_dict["id"].append(row["id"]) + output_dict["audio"].append(final_output) + output_dict["n_frames"].append(n_frames) + output_dict["tgt_text"].append(row["tgt_text"]) + output_dict["speaker"].append(row["speaker"]) + output_dict["src_text"].append(row["src_text"]) + output_dict["snr"].append(snr) + + out_tsv_path = Path(args.output_dir) / Path(args.audio_manifest).name + log.info(f"Saving manifest to {out_tsv_path.as_posix()}") + save_df_to_tsv(pd.DataFrame.from_dict(output_dict), out_tsv_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--audio-manifest", "-i", required=True, + type=str, help="path to the input manifest.") + parser.add_argument( + "--output-dir", "-o", required=True, type=str, + help="path to the output dir. it will contain files after denoising and" + " vad" + ) + parser.add_argument("--vad-agg-level", "-a", type=int, default=2, + help="the aggresive level of the vad [0-3].") + parser.add_argument( + "--dry-wet", "-dw", type=float, default=0.01, + help="the level of linear interpolation between noisy and enhanced " + "files." + ) + parser.add_argument( + "--device", "-d", type=str, default="cpu", + help="the device to be used for the speech enhancement model: " + "cpu | cuda." 
+ ) + parser.add_argument("--denoise", action="store_true", + help="apply a denoising") + parser.add_argument("--vad", action="store_true", help="apply a VAD") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py new file mode 100644 index 0000000000000000000000000000000000000000..3f70e73d6a37d32e05b6cf0e87f42e13c467cd52 --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py @@ -0,0 +1,473 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez + +import math +import time + +import torch as th +from torch import nn +from torch.nn import functional as F + +from .resample import downsample2, upsample2 +from .utils import capture_init + + +class BLSTM(nn.Module): + def __init__(self, dim, layers=2, bi=True): + super().__init__() + klass = nn.LSTM + self.lstm = klass( + bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim + ) + self.linear = None + if bi: + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x, hidden=None): + x, hidden = self.lstm(x, hidden) + if self.linear: + x = self.linear(x) + return x, hidden + + +def rescale_conv(conv, reference): + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): + rescale_conv(sub, reference) + + +class Demucs(nn.Module): + """ + Demucs speech enhancement model. + Args: + - chin (int): number of input channels. + - chout (int): number of output channels. + - hidden (int): number of initial hidden channels. + - depth (int): number of layers. + - kernel_size (int): kernel size for each layer. + - stride (int): stride for each layer. + - causal (bool): if false, uses BiLSTM instead of LSTM. + - resample (int): amount of resampling to apply to the input/output. + Can be one of 1, 2 or 4. + - growth (float): number of channels is multiplied by this for every layer. + - max_hidden (int): maximum number of channels. Can be useful to + control the size/speed of the model. + - normalize (bool): if true, normalize the input. + - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions. + - rescale (float): controls custom weight initialization. + See https://arxiv.org/abs/1911.13254. + - floor (float): stability flooring when normalizing. 
+ + """ + @capture_init + def __init__(self, + chin=1, + chout=1, + hidden=48, + depth=5, + kernel_size=8, + stride=4, + causal=True, + resample=4, + growth=2, + max_hidden=10_000, + normalize=True, + glu=True, + rescale=0.1, + floor=1e-3): + + super().__init__() + if resample not in [1, 2, 4]: + raise ValueError("Resample should be 1, 2 or 4.") + + self.chin = chin + self.chout = chout + self.hidden = hidden + self.depth = depth + self.kernel_size = kernel_size + self.stride = stride + self.causal = causal + self.floor = floor + self.resample = resample + self.normalize = normalize + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + activation = nn.GLU(1) if glu else nn.ReLU() + ch_scale = 2 if glu else 1 + + for index in range(depth): + encode = [] + encode += [ + nn.Conv1d(chin, hidden, kernel_size, stride), + nn.ReLU(), + nn.Conv1d(hidden, hidden * ch_scale, 1), activation, + ] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + decode += [ + nn.Conv1d(hidden, ch_scale * hidden, 1), activation, + nn.ConvTranspose1d(hidden, chout, kernel_size, stride), + ] + if index > 0: + decode.append(nn.ReLU()) + self.decoder.insert(0, nn.Sequential(*decode)) + chout = hidden + chin = hidden + hidden = min(int(growth * hidden), max_hidden) + + self.lstm = BLSTM(chin, bi=not causal) + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolutions, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + If the mixture has a valid length, the estimated sources + will have exactly the same length. + """ + length = math.ceil(length * self.resample) + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(length, 1) + for _ in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + length = int(math.ceil(length / self.resample)) + return int(length) + + @property + def total_stride(self): + return self.stride ** self.depth // self.resample + + def forward(self, mix): + if mix.dim() == 2: + mix = mix.unsqueeze(1) + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + mix = mix / (self.floor + std) + else: + std = 1 + length = mix.shape[-1] + x = mix + x = F.pad(x, (0, self.valid_length(length) - length)) + if self.resample == 2: + x = upsample2(x) + elif self.resample == 4: + x = upsample2(x) + x = upsample2(x) + skips = [] + for encode in self.encoder: + x = encode(x) + skips.append(x) + x = x.permute(2, 0, 1) + x, _ = self.lstm(x) + x = x.permute(1, 2, 0) + for decode in self.decoder: + skip = skips.pop(-1) + x = x + skip[..., :x.shape[-1]] + x = decode(x) + if self.resample == 2: + x = downsample2(x) + elif self.resample == 4: + x = downsample2(x) + x = downsample2(x) + + x = x[..., :length] + return std * x + + +def fast_conv(conv, x): + """ + Faster convolution evaluation if either kernel size is 1 + or length of sequence is 1. 
+ """ + batch, chin, length = x.shape + chout, chin, kernel = conv.weight.shape + assert batch == 1 + if kernel == 1: + x = x.view(chin, length) + out = th.addmm(conv.bias.view(-1, 1), + conv.weight.view(chout, chin), x) + elif length == kernel: + x = x.view(chin * kernel, 1) + out = th.addmm(conv.bias.view(-1, 1), + conv.weight.view(chout, chin * kernel), x) + else: + out = conv(x) + return out.view(batch, chout, -1) + + +class DemucsStreamer: + """ + Streaming implementation for Demucs. It supports being fed with any amount + of audio at a time. You will get back as much audio as possible at that + point. + + Args: + - demucs (Demucs): Demucs model. + - dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum + noise removal, 1 just returns the input signal. Small values > 0 + allows to limit distortions. + - num_frames (int): number of frames to process at once. Higher values + will increase overall latency but improve the real time factor. + - resample_lookahead (int): extra lookahead used for the resampling. + - resample_buffer (int): size of the buffer of previous inputs/outputs + kept for resampling. + """ + def __init__(self, demucs, + dry=0, + num_frames=1, + resample_lookahead=64, + resample_buffer=256): + device = next(iter(demucs.parameters())).device + self.demucs = demucs + self.lstm_state = None + self.conv_state = None + self.dry = dry + self.resample_lookahead = resample_lookahead + resample_buffer = min(demucs.total_stride, resample_buffer) + self.resample_buffer = resample_buffer + self.frame_length = demucs.valid_length(1) + \ + demucs.total_stride * (num_frames - 1) + self.total_length = self.frame_length + self.resample_lookahead + self.stride = demucs.total_stride * num_frames + self.resample_in = th.zeros(demucs.chin, resample_buffer, device=device) + self.resample_out = th.zeros( + demucs.chin, resample_buffer, device=device + ) + + self.frames = 0 + self.total_time = 0 + self.variance = 0 + self.pending = th.zeros(demucs.chin, 0, device=device) + + bias = demucs.decoder[0][2].bias + weight = demucs.decoder[0][2].weight + chin, chout, kernel = weight.shape + self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1) + self._weight = weight.permute(1, 2, 0).contiguous() + + def reset_time_per_frame(self): + self.total_time = 0 + self.frames = 0 + + @property + def time_per_frame(self): + return self.total_time / self.frames + + def flush(self): + """ + Flush remaining audio by padding it with zero. Call this + when you have no more input and want to get back the last chunk of audio. + """ + pending_length = self.pending.shape[1] + padding = th.zeros( + self.demucs.chin, self.total_length, device=self.pending.device + ) + out = self.feed(padding) + return out[:, :pending_length] + + def feed(self, wav): + """ + Apply the model to mix using true real time evaluation. + Normalization is done online as is the resampling. 
+ """ + begin = time.time() + demucs = self.demucs + resample_buffer = self.resample_buffer + stride = self.stride + resample = demucs.resample + + if wav.dim() != 2: + raise ValueError("input wav should be two dimensional.") + chin, _ = wav.shape + if chin != demucs.chin: + raise ValueError(f"Expected {demucs.chin} channels, got {chin}") + + self.pending = th.cat([self.pending, wav], dim=1) + outs = [] + while self.pending.shape[1] >= self.total_length: + self.frames += 1 + frame = self.pending[:, :self.total_length] + dry_signal = frame[:, :stride] + if demucs.normalize: + mono = frame.mean(0) + variance = (mono**2).mean() + self.variance = variance / self.frames + \ + (1 - 1 / self.frames) * self.variance + frame = frame / (demucs.floor + math.sqrt(self.variance)) + frame = th.cat([self.resample_in, frame], dim=-1) + self.resample_in[:] = frame[:, stride - resample_buffer:stride] + + if resample == 4: + frame = upsample2(upsample2(frame)) + elif resample == 2: + frame = upsample2(frame) + # remove pre sampling buffer + frame = frame[:, resample * resample_buffer:] + # remove extra samples after window + frame = frame[:, :resample * self.frame_length] + + out, extra = self._separate_frame(frame) + padded_out = th.cat([self.resample_out, out, extra], 1) + self.resample_out[:] = out[:, -resample_buffer:] + if resample == 4: + out = downsample2(downsample2(padded_out)) + elif resample == 2: + out = downsample2(padded_out) + else: + out = padded_out + + out = out[:, resample_buffer // resample:] + out = out[:, :stride] + + if demucs.normalize: + out *= math.sqrt(self.variance) + out = self.dry * dry_signal + (1 - self.dry) * out + outs.append(out) + self.pending = self.pending[:, stride:] + + self.total_time += time.time() - begin + if outs: + out = th.cat(outs, 1) + else: + out = th.zeros(chin, 0, device=wav.device) + return out + + def _separate_frame(self, frame): + demucs = self.demucs + skips = [] + next_state = [] + first = self.conv_state is None + stride = self.stride * demucs.resample + x = frame[None] + for idx, encode in enumerate(demucs.encoder): + stride //= demucs.stride + length = x.shape[2] + if idx == demucs.depth - 1: + # This is sligthly faster for the last conv + x = fast_conv(encode[0], x) + x = encode[1](x) + x = fast_conv(encode[2], x) + x = encode[3](x) + else: + if not first: + prev = self.conv_state.pop(0) + prev = prev[..., stride:] + tgt = (length - demucs.kernel_size) // demucs.stride + 1 + missing = tgt - prev.shape[-1] + offset = length - demucs.kernel_size - \ + demucs.stride * (missing - 1) + x = x[..., offset:] + x = encode[1](encode[0](x)) + x = fast_conv(encode[2], x) + x = encode[3](x) + if not first: + x = th.cat([prev, x], -1) + next_state.append(x) + skips.append(x) + + x = x.permute(2, 0, 1) + x, self.lstm_state = demucs.lstm(x, self.lstm_state) + x = x.permute(1, 2, 0) + # In the following, x contains only correct samples, i.e. the one + # for which each time position is covered by two window of the upper + # layer. extra contains extra samples to the right, and is used only as + # a better padding for the online resampling. 
+ extra = None + for idx, decode in enumerate(demucs.decoder): + skip = skips.pop(-1) + x += skip[..., :x.shape[-1]] + x = fast_conv(decode[0], x) + x = decode[1](x) + + if extra is not None: + skip = skip[..., x.shape[-1]:] + extra += skip[..., :extra.shape[-1]] + extra = decode[2](decode[1](decode[0](extra))) + x = decode[2](x) + next_state.append( + x[..., -demucs.stride:] - decode[2].bias.view(-1, 1) + ) + if extra is None: + extra = x[..., -demucs.stride:] + else: + extra[..., :demucs.stride] += next_state[-1] + x = x[..., :-demucs.stride] + + if not first: + prev = self.conv_state.pop(0) + x[..., :demucs.stride] += prev + if idx != demucs.depth - 1: + x = decode[3](x) + extra = decode[3](extra) + self.conv_state = next_state + return x[0], extra[0] + + +def test(): + import argparse + parser = argparse.ArgumentParser( + "denoiser.demucs", + description="Benchmark the streaming Demucs implementation, as well as " + "checking the delta with the offline implementation.") + parser.add_argument("--depth", default=5, type=int) + parser.add_argument("--resample", default=4, type=int) + parser.add_argument("--hidden", default=48, type=int) + parser.add_argument("--sample_rate", default=16000, type=float) + parser.add_argument("--device", default="cpu") + parser.add_argument("-t", "--num_threads", type=int) + parser.add_argument("-f", "--num_frames", type=int, default=1) + args = parser.parse_args() + if args.num_threads: + th.set_num_threads(args.num_threads) + sr = args.sample_rate + sr_ms = sr / 1000 + demucs = Demucs( + depth=args.depth, hidden=args.hidden, resample=args.resample + ).to(args.device) + x = th.randn(1, int(sr * 4)).to(args.device) + out = demucs(x[None])[0] + streamer = DemucsStreamer(demucs, num_frames=args.num_frames) + out_rt = [] + frame_size = streamer.total_length + with th.no_grad(): + while x.shape[1] > 0: + out_rt.append(streamer.feed(x[:, :frame_size])) + x = x[:, frame_size:] + frame_size = streamer.demucs.total_stride + out_rt.append(streamer.flush()) + out_rt = th.cat(out_rt, 1) + model_size = sum(p.numel() for p in demucs.parameters()) * 4 / 2**20 + initial_lag = streamer.total_length / sr_ms + tpf = 1000 * streamer.time_per_frame + print(f"model size: {model_size:.1f}MB, ", end='') + print(f"delta batch/streaming: {th.norm(out - out_rt) / th.norm(out):.2%}") + print(f"initial lag: {initial_lag:.1f}ms, ", end='') + print(f"stride: {streamer.stride * args.num_frames / sr_ms:.1f}ms") + print(f"time per frame: {tpf:.1f}ms, ", end='') + rtf = (1000 * streamer.time_per_frame) / (streamer.stride / sr_ms) + print(f"RTF: {rtf:.2f}") + print(f"Total lag with computation: {initial_lag + tpf:.1f}ms") + + +if __name__ == "__main__": + test() diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa846075b6872cdcc0baebca0b9acbb9ffcd287 --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
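The `test()` benchmark above compares offline and streaming outputs; a stripped-down version of the same idea, using random audio in place of a real 16 kHz waveform, looks like this (a sketch, not a tuned configuration):

```
# Offline vs. streaming enhancement with the Demucs model defined above.
import torch as th
from examples.speech_synthesis.preprocessing.denoiser.demucs import Demucs, DemucsStreamer

model = Demucs(hidden=48, depth=5, resample=4).eval()
wav = th.randn(1, 16000 * 2)                      # (channels, samples), ~2 s at 16 kHz

with th.no_grad():
    enhanced = model(wav[None])[0]                # offline: (1, 1, T) -> (1, T)

    streamer = DemucsStreamer(model, num_frames=1)
    chunks = [streamer.feed(wav[:, i:i + 4096])   # feed arbitrary-sized chunks
              for i in range(0, wav.shape[1], 4096)]
    chunks.append(streamer.flush())               # pad and emit the tail
    enhanced_rt = th.cat(chunks, dim=1)           # should closely track `enhanced`
```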
+# author: adefossez + +import logging + +import torch.hub + +from .demucs import Demucs +from .utils import deserialize_model + +logger = logging.getLogger(__name__) +ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/" +DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th" +DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th" +MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th" + + +def _demucs(pretrained, url, **kwargs): + model = Demucs(**kwargs) + if pretrained: + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu') + model.load_state_dict(state_dict) + return model + + +def dns48(pretrained=True): + return _demucs(pretrained, DNS_48_URL, hidden=48) + + +def dns64(pretrained=True): + return _demucs(pretrained, DNS_64_URL, hidden=64) + + +def master64(pretrained=True): + return _demucs(pretrained, MASTER_64_URL, hidden=64) + + +def add_model_flags(parser): + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument( + "-m", "--model_path", help="Path to local trained model." + ) + group.add_argument( + "--dns48", action="store_true", + help="Use pre-trained real time H=48 model trained on DNS." + ) + group.add_argument( + "--dns64", action="store_true", + help="Use pre-trained real time H=64 model trained on DNS." + ) + group.add_argument( + "--master64", action="store_true", + help="Use pre-trained real time H=64 model trained on DNS and Valentini." + ) + + +def get_model(args): + """ + Load local model package or torchhub pre-trained model. + """ + if args.model_path: + logger.info("Loading model from %s", args.model_path) + pkg = torch.load(args.model_path) + model = deserialize_model(pkg) + elif args.dns64: + logger.info("Loading pre-trained real time H=64 model trained on DNS.") + model = dns64() + elif args.master64: + logger.info( + "Loading pre-trained real time H=64 model trained on DNS and Valentini." + ) + model = master64() + else: + logger.info("Loading pre-trained real time H=48 model trained on DNS.") + model = dns48() + logger.debug(model) + return model diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..1222addc424d4f898d602009e4032907241aadfe --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez + +import math + +import torch as th +from torch.nn import functional as F + + +def sinc(t): + """sinc. + + :param t: the input tensor + """ + return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype), + th.sin(t) / t) + + +def kernel_upsample2(zeros=56): + """kernel_upsample2. + + """ + win = th.hann_window(4 * zeros + 1, periodic=False) + winodd = win[1::2] + t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) + t *= math.pi + kernel = (sinc(t) * winodd).view(1, 1, -1) + return kernel + + +def upsample2(x, zeros=56): + """ + Upsampling the input by 2 using sinc interpolation. + Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method." + ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing. + Vol. 9. IEEE, 1984. 
+ """ + *other, time = x.shape + kernel = kernel_upsample2(zeros).to(x) + out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view( + *other, time + ) + y = th.stack([x, out], dim=-1) + return y.view(*other, -1) + + +def kernel_downsample2(zeros=56): + """kernel_downsample2. + + """ + win = th.hann_window(4 * zeros + 1, periodic=False) + winodd = win[1::2] + t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) + t.mul_(math.pi) + kernel = (sinc(t) * winodd).view(1, 1, -1) + return kernel + + +def downsample2(x, zeros=56): + """ + Downsampling the input by 2 using sinc interpolation. + Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method." + ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing. + Vol. 9. IEEE, 1984. + """ + if x.shape[-1] % 2 != 0: + x = F.pad(x, (0, 1)) + xeven = x[..., ::2] + xodd = x[..., 1::2] + *other, time = xodd.shape + kernel = kernel_downsample2(zeros).to(x) + out = xeven + F.conv1d( + xodd.view(-1, 1, time), kernel, padding=zeros + )[..., :-1].view(*other, time) + return out.view(*other, -1).mul(0.5) diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..734d047f1bb8e3aa98c88e152eee7f91fea3d814 --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez + +import functools +import logging +from contextlib import contextmanager +import inspect +import time + +logger = logging.getLogger(__name__) + +EPS = 1e-8 + + +def capture_init(init): + """capture_init. + + Decorate `__init__` with this, and you can then + recover the *args and **kwargs passed to it in `self._init_args_kwargs` + """ + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ + + +def deserialize_model(package, strict=False): + """deserialize_model. + + """ + klass = package['class'] + if strict: + model = klass(*package['args'], **package['kwargs']) + else: + sig = inspect.signature(klass) + kw = package['kwargs'] + for key in list(kw): + if key not in sig.parameters: + logger.warning("Dropping inexistant parameter %s", key) + del kw[key] + model = klass(*package['args'], **kw) + model.load_state_dict(package['state']) + return model + + +def copy_state(state): + return {k: v.cpu().clone() for k, v in state.items()} + + +def serialize_model(model): + args, kwargs = model._init_args_kwargs + state = copy_state(model.state_dict()) + return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state} + + +@contextmanager +def swap_state(model, state): + """ + Context manager that swaps the state of a model, e.g: + + # model is in old state + with swap_state(model, new_state): + # model in new state + # model back to old state + """ + old_state = copy_state(model.state_dict()) + model.load_state_dict(state) + try: + yield + finally: + model.load_state_dict(old_state) + + +def pull_metric(history, name): + out = [] + for metrics in history: + if name in metrics: + out.append(metrics[name]) + return out + + +class LogProgress: + """ + Sort of like tqdm but using log lines and not as real time. 
+ Args: + - logger: logger obtained from `logging.getLogger`, + - iterable: iterable object to wrap + - updates (int): number of lines that will be printed, e.g. + if `updates=5`, log every 1/5th of the total length. + - total (int): length of the iterable, in case it does not support + `len`. + - name (str): prefix to use in the log. + - level: logging level (like `logging.INFO`). + """ + def __init__(self, + logger, + iterable, + updates=5, + total=None, + name="LogProgress", + level=logging.INFO): + self.iterable = iterable + self.total = total or len(iterable) + self.updates = updates + self.name = name + self.logger = logger + self.level = level + + def update(self, **infos): + self._infos = infos + + def __iter__(self): + self._iterator = iter(self.iterable) + self._index = -1 + self._infos = {} + self._begin = time.time() + return self + + def __next__(self): + self._index += 1 + try: + value = next(self._iterator) + except StopIteration: + raise + else: + return value + finally: + log_every = max(1, self.total // self.updates) + # logging is delayed by 1 it, in order to have the metrics from update + if self._index >= 1 and self._index % log_every == 0: + self._log() + + def _log(self): + self._speed = (1 + self._index) / (time.time() - self._begin) + infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items()) + if self._speed < 1e-4: + speed = "oo sec/it" + elif self._speed < 0.1: + speed = f"{1/self._speed:.1f} sec/it" + else: + speed = f"{self._speed:.1f} it/sec" + out = f"{self.name} | {self._index}/{self.total} | {speed}" + if infos: + out += " | " + infos + self.logger.log(self.level, out) + + +def colorize(text, color): + """ + Display text with some ANSI color in the terminal. + """ + code = f"\033[{color}m" + restore = "\033[0m" + return "".join([code, text, restore]) + + +def bold(text): + """ + Display text in bold in the terminal. + """ + return colorize(text, "1") + + +def cal_snr(lbl, est): + import torch + y = 10.0 * torch.log10( + torch.sum(lbl**2, dim=-1) / (torch.sum((est-lbl)**2, dim=-1) + EPS) + + EPS + ) + return y diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py b/fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..a30254604311a488a1d4959f941051890ed32b2e --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py @@ -0,0 +1,140 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
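`capture_init`, `serialize_model` and `deserialize_model` form a small checkpointing round-trip: the decorator records the constructor arguments so a package can later rebuild the module and reload its weights. A sketch with a toy module (not part of the package):

```
# Hypothetical round-trip through serialize_model/deserialize_model.
import torch
from torch import nn
from examples.speech_synthesis.preprocessing.denoiser.utils import (
    capture_init, serialize_model, deserialize_model
)

class Toy(nn.Module):
    @capture_init                      # records (args, kwargs) in _init_args_kwargs
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

model = Toy(dim=4)
package = serialize_model(model)       # {"class", "args", "kwargs", "state"}
restored = deserialize_model(package)  # rebuilds Toy(dim=4), loads the state dict
x = torch.ones(1, 4)
assert torch.allclose(model(x), restored(x))
```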
+ +import argparse +import logging +from pathlib import Path +from collections import defaultdict +from typing import List, Dict, Tuple + +import pandas as pd +import numpy as np +import torchaudio +from tqdm import tqdm + +from examples.speech_to_text.data_utils import load_df_from_tsv, save_df_to_tsv + + +log = logging.getLogger(__name__) + +SPLITS = ["train", "dev", "test"] + + +def get_top_n( + root: Path, n_speakers: int = 10, min_n_tokens: int = 5 +) -> pd.DataFrame: + df = load_df_from_tsv(root / "validated.tsv") + df["n_tokens"] = [len(s.split()) for s in df["sentence"]] + df = df[df["n_tokens"] >= min_n_tokens] + df["n_frames"] = [ + torchaudio.info((root / "clips" / p).as_posix()).num_frames + for p in tqdm(df["path"]) + ] + df["id"] = [Path(p).stem for p in df["path"]] + total_duration_ms = df.groupby("client_id")["n_frames"].agg(["sum"]) + total_duration_ms = total_duration_ms.sort_values("sum", ascending=False) + + top_n_total_duration_ms = total_duration_ms.head(n_speakers) + top_n_client_ids = set(top_n_total_duration_ms.index.tolist()) + df_top_n = df[df["client_id"].isin(top_n_client_ids)] + return df_top_n + + +def get_splits( + df, train_split_ratio=0.99, speaker_in_all_splits=False, rand_seed=0 +) -> Tuple[Dict[str, str], List[str]]: + np.random.seed(rand_seed) + dev_split_ratio = (1. - train_split_ratio) / 3 + grouped = list(df.groupby("client_id")) + id_to_split = {} + for _, cur_df in tqdm(grouped): + cur_n_examples = len(cur_df) + if speaker_in_all_splits and cur_n_examples < 3: + continue + cur_n_train = int(cur_n_examples * train_split_ratio) + cur_n_dev = int(cur_n_examples * dev_split_ratio) + cur_n_test = cur_n_examples - cur_n_dev - cur_n_train + if speaker_in_all_splits and cur_n_dev * cur_n_test == 0: + cur_n_dev, cur_n_test = 1, 1 + cur_n_train = cur_n_examples - cur_n_dev - cur_n_test + cur_indices = cur_df.index.tolist() + cur_shuffled_indices = np.random.permutation(cur_n_examples) + cur_shuffled_indices = [cur_indices[i] for i in cur_shuffled_indices] + cur_indices_by_split = { + "train": cur_shuffled_indices[:cur_n_train], + "dev": cur_shuffled_indices[cur_n_train: cur_n_train + cur_n_dev], + "test": cur_shuffled_indices[cur_n_train + cur_n_dev:] + } + for split in SPLITS: + for i in cur_indices_by_split[split]: + id_ = df["id"].loc[i] + id_to_split[id_] = split + return id_to_split, sorted(df["client_id"].unique()) + + +def convert_to_wav(root: Path, filenames: List[str], target_sr=16_000): + out_root = root / "wav" + out_root.mkdir(exist_ok=True, parents=True) + print("Converting to WAV...") + for n in tqdm(filenames): + in_path = (root / "clips" / n).as_posix() + waveform, sr = torchaudio.load(in_path) + converted, converted_sr = torchaudio.sox_effects.apply_effects_tensor( + waveform, sr, [["rate", str(target_sr)], ["channels", "1"]] + ) + out_path = (out_root / Path(n).with_suffix(".wav").name).as_posix() + torchaudio.save(out_path, converted, converted_sr, encoding="PCM_S", + bits_per_sample=16) + + +def process(args): + data_root = Path(args.data_root).absolute() / args.lang + + # Generate TSV manifest + print("Generating manifest...") + + df_top_n = get_top_n(data_root) + id_to_split, speakers = get_splits(df_top_n) + + if args.convert_to_wav: + convert_to_wav(data_root, df_top_n["path"].tolist()) + + manifest_by_split = {split: defaultdict(list) for split in SPLITS} + for sample in tqdm(df_top_n.to_dict(orient="index").values()): + sample_id = sample["id"] + split = id_to_split[sample_id] + manifest_by_split[split]["id"].append(sample_id) + if 
args.convert_to_wav: + audio_path = data_root / "wav" / f"{sample_id}.wav" + else: + audio_path = data_root / "clips" / f"{sample_id}.mp3" + manifest_by_split[split]["audio"].append(audio_path.as_posix()) + manifest_by_split[split]["n_frames"].append(sample["n_frames"]) + manifest_by_split[split]["tgt_text"].append(sample["sentence"]) + manifest_by_split[split]["speaker"].append(sample["client_id"]) + manifest_by_split[split]["src_text"].append(sample["sentence"]) + + output_root = Path(args.output_manifest_root).absolute() + output_root.mkdir(parents=True, exist_ok=True) + for split in SPLITS: + save_df_to_tsv( + pd.DataFrame.from_dict(manifest_by_split[split]), + output_root / f"{split}.audio.tsv" + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument("--output-manifest-root", "-m", required=True, type=str) + parser.add_argument("--lang", "-l", required=True, type=str) + parser.add_argument("--convert-to-wav", action="store_true") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py b/fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..4a1e119b327c0ff4fac105bd7a83fcb547eb3c2d --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py @@ -0,0 +1,262 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +from pathlib import Path +import shutil +from tempfile import NamedTemporaryFile +from collections import Counter, defaultdict + +import pandas as pd +import torchaudio +from tqdm import tqdm + +from fairseq.data.audio.audio_utils import convert_waveform +from examples.speech_to_text.data_utils import ( + create_zip, + gen_config_yaml, + gen_vocab, + get_zip_manifest, + load_tsv_to_dicts, + save_df_to_tsv +) +from examples.speech_synthesis.data_utils import ( + extract_logmel_spectrogram, extract_pitch, extract_energy, get_global_cmvn, + ipa_phonemize, get_mfa_alignment, get_unit_alignment, + get_feature_value_min_max +) + + +log = logging.getLogger(__name__) + + +def process(args): + assert "train" in args.splits + out_root = Path(args.output_root).absolute() + out_root.mkdir(exist_ok=True) + + print("Fetching data...") + audio_manifest_root = Path(args.audio_manifest_root).absolute() + samples = [] + for s in args.splits: + for e in load_tsv_to_dicts(audio_manifest_root / f"{s}.audio.tsv"): + e["split"] = s + samples.append(e) + sample_ids = [s["id"] for s in samples] + + # Get alignment info + id_to_alignment = None + if args.textgrid_zip is not None: + assert args.id_to_units_tsv is None + id_to_alignment = get_mfa_alignment( + args.textgrid_zip, sample_ids, args.sample_rate, args.hop_length + ) + elif args.id_to_units_tsv is not None: + # assume identical hop length on the unit sequence + id_to_alignment = get_unit_alignment(args.id_to_units_tsv, sample_ids) + + # Extract features and pack features into ZIP + feature_name = "logmelspec80" + zip_path = out_root / f"{feature_name}.zip" + pitch_zip_path = out_root / "pitch.zip" + energy_zip_path = out_root / "energy.zip" + gcmvn_npz_path = out_root / "gcmvn_stats.npz" + if zip_path.exists() and gcmvn_npz_path.exists(): + print(f"{zip_path} and 
{gcmvn_npz_path} exist.") + else: + feature_root = out_root / feature_name + feature_root.mkdir(exist_ok=True) + pitch_root = out_root / "pitch" + energy_root = out_root / "energy" + if args.add_fastspeech_targets: + pitch_root.mkdir(exist_ok=True) + energy_root.mkdir(exist_ok=True) + print("Extracting Mel spectrogram features...") + for sample in tqdm(samples): + waveform, sample_rate = torchaudio.load(sample["audio"]) + waveform, sample_rate = convert_waveform( + waveform, sample_rate, normalize_volume=args.normalize_volume, + to_sample_rate=args.sample_rate + ) + sample_id = sample["id"] + target_length = None + if id_to_alignment is not None: + a = id_to_alignment[sample_id] + target_length = sum(a.frame_durations) + if a.start_sec is not None and a.end_sec is not None: + start_frame = int(a.start_sec * sample_rate) + end_frame = int(a.end_sec * sample_rate) + waveform = waveform[:, start_frame: end_frame] + extract_logmel_spectrogram( + waveform, sample_rate, feature_root / f"{sample_id}.npy", + win_length=args.win_length, hop_length=args.hop_length, + n_fft=args.n_fft, n_mels=args.n_mels, f_min=args.f_min, + f_max=args.f_max, target_length=target_length + ) + if args.add_fastspeech_targets: + assert id_to_alignment is not None + extract_pitch( + waveform, sample_rate, pitch_root / f"{sample_id}.npy", + hop_length=args.hop_length, log_scale=True, + phoneme_durations=id_to_alignment[sample_id].frame_durations + ) + extract_energy( + waveform, energy_root / f"{sample_id}.npy", + hop_length=args.hop_length, n_fft=args.n_fft, + log_scale=True, + phoneme_durations=id_to_alignment[sample_id].frame_durations + ) + print("ZIPing features...") + create_zip(feature_root, zip_path) + get_global_cmvn(feature_root, gcmvn_npz_path) + shutil.rmtree(feature_root) + if args.add_fastspeech_targets: + create_zip(pitch_root, pitch_zip_path) + shutil.rmtree(pitch_root) + create_zip(energy_root, energy_zip_path) + shutil.rmtree(energy_root) + + print("Fetching ZIP manifest...") + audio_paths, audio_lengths = get_zip_manifest(zip_path) + pitch_paths, pitch_lengths, energy_paths, energy_lengths = [None] * 4 + if args.add_fastspeech_targets: + pitch_paths, pitch_lengths = get_zip_manifest(pitch_zip_path) + energy_paths, energy_lengths = get_zip_manifest(energy_zip_path) + # Generate TSV manifest + print("Generating manifest...") + id_to_cer = None + if args.cer_threshold is not None: + assert Path(args.cer_tsv_path).is_file() + id_to_cer = { + x["id"]: x["uer"] for x in load_tsv_to_dicts(args.cer_tsv_path) + } + manifest_by_split = {split: defaultdict(list) for split in args.splits} + for sample in tqdm(samples): + sample_id, split = sample["id"], sample["split"] + + if args.snr_threshold is not None and "snr" in sample \ + and sample["snr"] < args.snr_threshold: + continue + if args.cer_threshold is not None \ + and id_to_cer[sample_id] > args.cer_threhold: + continue + + normalized_utt = sample["tgt_text"] + if id_to_alignment is not None: + normalized_utt = " ".join(id_to_alignment[sample_id].tokens) + elif args.ipa_vocab: + normalized_utt = ipa_phonemize( + normalized_utt, lang=args.lang, use_g2p=args.use_g2p + ) + manifest_by_split[split]["id"].append(sample_id) + manifest_by_split[split]["audio"].append(audio_paths[sample_id]) + manifest_by_split[split]["n_frames"].append(audio_lengths[sample_id]) + manifest_by_split[split]["tgt_text"].append(normalized_utt) + manifest_by_split[split]["speaker"].append(sample["speaker"]) + manifest_by_split[split]["src_text"].append(sample["src_text"]) + if 
args.add_fastspeech_targets: + assert id_to_alignment is not None + duration = " ".join( + str(d) for d in id_to_alignment[sample_id].frame_durations + ) + manifest_by_split[split]["duration"].append(duration) + manifest_by_split[split]["pitch"].append(pitch_paths[sample_id]) + manifest_by_split[split]["energy"].append(energy_paths[sample_id]) + for split in args.splits: + save_df_to_tsv( + pd.DataFrame.from_dict(manifest_by_split[split]), + out_root / f"{split}.tsv" + ) + # Generate vocab + vocab_name, spm_filename = None, None + if id_to_alignment is not None or args.ipa_vocab: + vocab = Counter() + for t in manifest_by_split["train"]["tgt_text"]: + vocab.update(t.split(" ")) + vocab_name = "vocab.txt" + with open(out_root / vocab_name, "w") as f: + for s, c in vocab.most_common(): + f.write(f"{s} {c}\n") + else: + spm_filename_prefix = "spm_char" + spm_filename = f"{spm_filename_prefix}.model" + with NamedTemporaryFile(mode="w") as f: + for t in manifest_by_split["train"]["tgt_text"]: + f.write(t + "\n") + f.flush() # needed to ensure gen_vocab sees dumped text + gen_vocab(Path(f.name), out_root / spm_filename_prefix, "char") + # Generate speaker list + speakers = sorted({sample["speaker"] for sample in samples}) + speakers_path = out_root / "speakers.txt" + with open(speakers_path, "w") as f: + for speaker in speakers: + f.write(f"{speaker}\n") + # Generate config YAML + win_len_t = args.win_length / args.sample_rate + hop_len_t = args.hop_length / args.sample_rate + extra = { + "sample_rate": args.sample_rate, + "features": { + "type": "spectrogram+melscale+log", + "eps": 1e-5, "n_mels": args.n_mels, "n_fft": args.n_fft, + "window_fn": "hann", "win_length": args.win_length, + "hop_length": args.hop_length, "sample_rate": args.sample_rate, + "win_len_t": win_len_t, "hop_len_t": hop_len_t, + "f_min": args.f_min, "f_max": args.f_max, + "n_stft": args.n_fft // 2 + 1 + } + } + if len(speakers) > 1: + extra["speaker_set_filename"] = "speakers.txt" + if args.add_fastspeech_targets: + pitch_min, pitch_max = get_feature_value_min_max( + [(out_root / n).as_posix() for n in pitch_paths.values()] + ) + energy_min, energy_max = get_feature_value_min_max( + [(out_root / n).as_posix() for n in energy_paths.values()] + ) + extra["features"]["pitch_min"] = pitch_min + extra["features"]["pitch_max"] = pitch_max + extra["features"]["energy_min"] = energy_min + extra["features"]["energy_max"] = energy_max + gen_config_yaml( + out_root, spm_filename=spm_filename, vocab_name=vocab_name, + audio_root=out_root.as_posix(), input_channels=None, + input_feat_per_channel=None, specaugment_policy=None, + cmvn_type="global", gcmvn_path=gcmvn_npz_path, extra=extra + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--audio-manifest-root", "-m", required=True, type=str) + parser.add_argument("--output-root", "-o", required=True, type=str) + parser.add_argument("--splits", "-s", type=str, nargs="+", + default=["train", "dev", "test"]) + parser.add_argument("--ipa-vocab", action="store_true") + parser.add_argument("--use-g2p", action="store_true") + parser.add_argument("--lang", type=str, default="en-us") + parser.add_argument("--win-length", type=int, default=1024) + parser.add_argument("--hop-length", type=int, default=256) + parser.add_argument("--n-fft", type=int, default=1024) + parser.add_argument("--n-mels", type=int, default=80) + parser.add_argument("--f-min", type=int, default=20) + parser.add_argument("--f-max", type=int, default=8000) + parser.add_argument("--sample-rate", 
type=int, default=22050) + parser.add_argument("--normalize-volume", "-n", action="store_true") + parser.add_argument("--textgrid-zip", type=str, default=None) + parser.add_argument("--id-to-units-tsv", type=str, default=None) + parser.add_argument("--add-fastspeech-targets", action="store_true") + parser.add_argument("--snr-threshold", type=float, default=None) + parser.add_argument("--cer-threshold", type=float, default=None) + parser.add_argument("--cer-tsv-path", type=str, default="") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py b/fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec1fb7521b8a9b821d28bcaaaedb034f6e95e0b --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +from pathlib import Path +from collections import defaultdict + +import pandas as pd +from torchaudio.datasets import LJSPEECH +from tqdm import tqdm + +from examples.speech_to_text.data_utils import save_df_to_tsv + + +log = logging.getLogger(__name__) + +SPLITS = ["train", "dev", "test"] + + +def process(args): + out_root = Path(args.output_data_root).absolute() + out_root.mkdir(parents=True, exist_ok=True) + + # Generate TSV manifest + print("Generating manifest...") + # following FastSpeech's splits + dataset = LJSPEECH(out_root.as_posix(), download=True) + id_to_split = {} + for x in dataset._flist: + id_ = x[0] + speaker = id_.split("-")[0] + id_to_split[id_] = { + "LJ001": "test", "LJ002": "test", "LJ003": "dev" + }.get(speaker, "train") + manifest_by_split = {split: defaultdict(list) for split in SPLITS} + progress = tqdm(enumerate(dataset), total=len(dataset)) + for i, (waveform, _, utt, normalized_utt) in progress: + sample_id = dataset._flist[i][0] + split = id_to_split[sample_id] + manifest_by_split[split]["id"].append(sample_id) + audio_path = f"{dataset._path}/{sample_id}.wav" + manifest_by_split[split]["audio"].append(audio_path) + manifest_by_split[split]["n_frames"].append(len(waveform[0])) + manifest_by_split[split]["tgt_text"].append(normalized_utt) + manifest_by_split[split]["speaker"].append("ljspeech") + manifest_by_split[split]["src_text"].append(utt) + + manifest_root = Path(args.output_manifest_root).absolute() + manifest_root.mkdir(parents=True, exist_ok=True) + for split in SPLITS: + save_df_to_tsv( + pd.DataFrame.from_dict(manifest_by_split[split]), + manifest_root / f"{split}.audio.tsv" + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--output-data-root", "-d", required=True, type=str) + parser.add_argument("--output-manifest-root", "-m", required=True, type=str) + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py b/fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3e4c5cd7aef15dae0b41b0ec7b33e17f66597f --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +from collections import defaultdict +from itertools import chain +from pathlib import Path + +import numpy as np +import torchaudio +import torchaudio.sox_effects as ta_sox +import yaml +from tqdm import tqdm + +from examples.speech_to_text.data_utils import load_tsv_to_dicts +from examples.speech_synthesis.preprocessing.speaker_embedder import SpkrEmbedder + + +def extract_embedding(audio_path, embedder): + wav, sr = torchaudio.load(audio_path) # 2D + if sr != embedder.RATE: + wav, sr = ta_sox.apply_effects_tensor( + wav, sr, [["rate", str(embedder.RATE)]] + ) + try: + emb = embedder([wav[0].cuda().float()]).cpu().numpy() + except RuntimeError: + emb = None + return emb + + +def process(args): + print("Fetching data...") + raw_manifest_root = Path(args.raw_manifest_root).absolute() + samples = [load_tsv_to_dicts(raw_manifest_root / (s + ".tsv")) + for s in args.splits] + samples = list(chain(*samples)) + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + with open(f"{config['audio_root']}/{config['speaker_set_filename']}") as f: + speaker_to_id = {r.strip(): i for i, r in enumerate(f)} + + embedder = SpkrEmbedder(args.ckpt).cuda() + speaker_to_cnt = defaultdict(float) + speaker_to_emb = defaultdict(float) + for sample in tqdm(samples, desc="extract emb"): + emb = extract_embedding(sample["audio"], embedder) + if emb is not None: + speaker_to_cnt[sample["speaker"]] += 1 + speaker_to_emb[sample["speaker"]] += emb + if len(speaker_to_emb) != len(speaker_to_id): + missed = set(speaker_to_id) - set(speaker_to_emb.keys()) + print( + f"WARNING: missing embeddings for {len(missed)} speaker:\n{missed}" + ) + speaker_emb_mat = np.zeros((len(speaker_to_id), len(emb)), float) + for speaker in speaker_to_emb: + idx = speaker_to_id[speaker] + emb = speaker_to_emb[speaker] + cnt = speaker_to_cnt[speaker] + speaker_emb_mat[idx, :] = emb / cnt + speaker_emb_name = "speaker_emb.npy" + speaker_emb_path = f"{config['audio_root']}/{speaker_emb_name}" + np.save(speaker_emb_path, speaker_emb_mat) + config["speaker_emb_filename"] = speaker_emb_name + + with open(args.new_config, "w") as f: + yaml.dump(config, f) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--raw-manifest-root", "-m", required=True, type=str) + parser.add_argument("--splits", "-s", type=str, nargs="+", + default=["train"]) + parser.add_argument("--config", "-c", required=True, type=str) + parser.add_argument("--new-config", "-n", required=True, type=str) + parser.add_argument("--ckpt", required=True, type=str, + help="speaker embedder checkpoint") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py b/fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..7afa40fcd195465a225c9f251734e84fe6b3c7ef --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
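`get_speaker_embedding.py` averages per-utterance d-vectors into one embedding per speaker and registers the resulting matrix in a new data config. A sketch of a direct call (CUDA is required, since the embedder and waveforms are moved to GPU; every path is a placeholder):

```
# Hypothetical driver for process(); the manifests and config come from the
# earlier manifest/feature preprocessing steps.
from argparse import Namespace
from examples.speech_synthesis.preprocessing.get_speaker_embedding import process

process(Namespace(
    raw_manifest_root="manifests",            # contains <split>.tsv with audio/speaker columns
    splits=["train"],
    config="features/config.yaml",            # must define audio_root and speaker_set_filename
    new_config="features/config_spkemb.yaml",
    ckpt="checkpoints/speaker_embedder.pt",   # d-vector embedder checkpoint
))
```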
+ +import argparse +import logging +import numpy as np +import re +from pathlib import Path +from collections import defaultdict + +import pandas as pd +from torchaudio.datasets import VCTK +from tqdm import tqdm + +from examples.speech_to_text.data_utils import save_df_to_tsv + + +log = logging.getLogger(__name__) + +SPLITS = ["train", "dev", "test"] + + +def normalize_text(text): + return re.sub(r"[^a-zA-Z.?!,'\- ]", '', text) + + +def process(args): + out_root = Path(args.output_data_root).absolute() + out_root.mkdir(parents=True, exist_ok=True) + + # Generate TSV manifest + print("Generating manifest...") + dataset = VCTK(out_root.as_posix(), download=False) + ids = list(dataset._walker) + np.random.seed(args.seed) + np.random.shuffle(ids) + n_train = len(ids) - args.n_dev - args.n_test + _split = ["train"] * n_train + ["dev"] * args.n_dev + ["test"] * args.n_test + id_to_split = dict(zip(ids, _split)) + manifest_by_split = {split: defaultdict(list) for split in SPLITS} + progress = tqdm(enumerate(dataset), total=len(dataset)) + for i, (waveform, _, text, speaker_id, _) in progress: + sample_id = dataset._walker[i] + _split = id_to_split[sample_id] + audio_dir = Path(dataset._path) / dataset._folder_audio / speaker_id + audio_path = audio_dir / f"{sample_id}.wav" + text = normalize_text(text) + manifest_by_split[_split]["id"].append(sample_id) + manifest_by_split[_split]["audio"].append(audio_path.as_posix()) + manifest_by_split[_split]["n_frames"].append(len(waveform[0])) + manifest_by_split[_split]["tgt_text"].append(text) + manifest_by_split[_split]["speaker"].append(speaker_id) + manifest_by_split[_split]["src_text"].append(text) + + manifest_root = Path(args.output_manifest_root).absolute() + manifest_root.mkdir(parents=True, exist_ok=True) + for _split in SPLITS: + save_df_to_tsv( + pd.DataFrame.from_dict(manifest_by_split[_split]), + manifest_root / f"{_split}.audio.tsv" + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--output-data-root", "-d", required=True, type=str) + parser.add_argument("--output-manifest-root", "-m", required=True, type=str) + parser.add_argument("--n-dev", default=50, type=int) + parser.add_argument("--n-test", default=100, type=int) + parser.add_argument("--seed", "-s", default=1234, type=int) + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py b/fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b178676ba322ef613df42977cb498101f841b09 --- /dev/null +++ b/fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
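The VCTK manifest script above builds its train/dev/test split by seeding NumPy's RNG, shuffling the utterance ids, and assigning the first `len(ids) - n_dev - n_test` ids to train, the next `n_dev` to dev, and the last `n_test` to test. A minimal sketch of that assignment on hypothetical ids (the `p225_*` names are made up), showing that a fixed `--seed` makes the split reproducible:
```
# Seeded split assignment as in get_vctk_audio_manifest.py, on toy ids.
import numpy as np

ids = [f"p225_{i:03d}" for i in range(10)]   # stand-ins for dataset._walker
n_dev, n_test, seed = 2, 3, 1234

np.random.seed(seed)
shuffled = list(ids)
np.random.shuffle(shuffled)                  # in-place, order fixed by the seed

n_train = len(shuffled) - n_dev - n_test
splits = ["train"] * n_train + ["dev"] * n_dev + ["test"] * n_test
id_to_split = dict(zip(shuffled, splits))

assert sum(s == "dev" for s in id_to_split.values()) == n_dev
assert sum(s == "test" for s in id_to_split.values()) == n_test
print(id_to_split)                           # same mapping on every run
```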
+ + +import librosa +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +import torchaudio + + +EMBEDDER_PARAMS = { + 'num_mels': 40, + 'n_fft': 512, + 'emb_dim': 256, + 'lstm_hidden': 768, + 'lstm_layers': 3, + 'window': 80, + 'stride': 40, +} + + +def set_requires_grad(nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary + computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + +class LinearNorm(nn.Module): + def __init__(self, hp): + super(LinearNorm, self).__init__() + self.linear_layer = nn.Linear(hp["lstm_hidden"], hp["emb_dim"]) + + def forward(self, x): + return self.linear_layer(x) + + +class SpeechEmbedder(nn.Module): + def __init__(self, hp): + super(SpeechEmbedder, self).__init__() + self.lstm = nn.LSTM(hp["num_mels"], + hp["lstm_hidden"], + num_layers=hp["lstm_layers"], + batch_first=True) + self.proj = LinearNorm(hp) + self.hp = hp + + def forward(self, mel): + # (num_mels, T) -> (num_mels, T', window) + mels = mel.unfold(1, self.hp["window"], self.hp["stride"]) + mels = mels.permute(1, 2, 0) # (T', window, num_mels) + x, _ = self.lstm(mels) # (T', window, lstm_hidden) + x = x[:, -1, :] # (T', lstm_hidden), use last frame only + x = self.proj(x) # (T', emb_dim) + x = x / torch.norm(x, p=2, dim=1, keepdim=True) # (T', emb_dim) + + x = x.mean(dim=0) + if x.norm(p=2) != 0: + x = x / x.norm(p=2) + return x + + +class SpkrEmbedder(nn.Module): + RATE = 16000 + + def __init__( + self, + embedder_path, + embedder_params=EMBEDDER_PARAMS, + rate=16000, + hop_length=160, + win_length=400, + pad=False, + ): + super(SpkrEmbedder, self).__init__() + embedder_pt = torch.load(embedder_path, map_location="cpu") + self.embedder = SpeechEmbedder(embedder_params) + self.embedder.load_state_dict(embedder_pt) + self.embedder.eval() + set_requires_grad(self.embedder, requires_grad=False) + self.embedder_params = embedder_params + + self.register_buffer('mel_basis', torch.from_numpy( + librosa.filters.mel( + sr=self.RATE, + n_fft=self.embedder_params["n_fft"], + n_mels=self.embedder_params["num_mels"]) + ) + ) + + self.resample = None + if rate != self.RATE: + self.resample = torchaudio.transforms.Resample(rate, self.RATE) + self.hop_length = hop_length + self.win_length = win_length + self.pad = pad + + def get_mel(self, y): + if self.pad and y.shape[-1] < 14000: + y = F.pad(y, (0, 14000 - y.shape[-1])) + + window = torch.hann_window(self.win_length).to(y) + y = torch.stft(y, n_fft=self.embedder_params["n_fft"], + hop_length=self.hop_length, + win_length=self.win_length, + window=window) + magnitudes = torch.norm(y, dim=-1, p=2) ** 2 + mel = torch.log10(self.mel_basis @ magnitudes + 1e-6) + return mel + + def forward(self, inputs): + dvecs = [] + for wav in inputs: + mel = self.get_mel(wav) + if mel.dim() == 3: + mel = mel.squeeze(0) + dvecs += [self.embedder(mel)] + dvecs = torch.stack(dvecs) + + dvec = torch.mean(dvecs, dim=0) + dvec = dvec / torch.norm(dvec) + + return dvec diff --git a/fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py b/fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9cf121081fbde2f5085ed380f0841649d143a4be --- /dev/null +++ 
b/fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import collections +import contextlib +import wave + +try: + import webrtcvad +except ImportError: + raise ImportError("Please install py-webrtcvad: pip install webrtcvad") +import argparse +import os +import logging +from tqdm import tqdm + +AUDIO_SUFFIX = '.wav' +FS_MS = 30 +SCALE = 6e-5 +THRESHOLD = 0.3 + + +def read_wave(path): + """Reads a .wav file. + Takes the path, and returns (PCM audio data, sample rate). + """ + with contextlib.closing(wave.open(path, 'rb')) as wf: + num_channels = wf.getnchannels() + assert num_channels == 1 + sample_width = wf.getsampwidth() + assert sample_width == 2 + sample_rate = wf.getframerate() + assert sample_rate in (8000, 16000, 32000, 48000) + pcm_data = wf.readframes(wf.getnframes()) + return pcm_data, sample_rate + + +def write_wave(path, audio, sample_rate): + """Writes a .wav file. + Takes path, PCM audio data, and sample rate. + """ + with contextlib.closing(wave.open(path, 'wb')) as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(audio) + + +class Frame(object): + """Represents a "frame" of audio data.""" + def __init__(self, bytes, timestamp, duration): + self.bytes = bytes + self.timestamp = timestamp + self.duration = duration + + +def frame_generator(frame_duration_ms, audio, sample_rate): + """Generates audio frames from PCM audio data. + Takes the desired frame duration in milliseconds, the PCM data, and + the sample rate. + Yields Frames of the requested duration. + """ + n = int(sample_rate * (frame_duration_ms / 1000.0) * 2) + offset = 0 + timestamp = 0.0 + duration = (float(n) / sample_rate) / 2.0 + while offset + n < len(audio): + yield Frame(audio[offset:offset + n], timestamp, duration) + timestamp += duration + offset += n + + +def vad_collector(sample_rate, frame_duration_ms, + padding_duration_ms, vad, frames): + """Filters out non-voiced audio frames. + Given a webrtcvad.Vad and a source of audio frames, yields only + the voiced audio. + Uses a padded, sliding window algorithm over the audio frames. + When more than 90% of the frames in the window are voiced (as + reported by the VAD), the collector triggers and begins yielding + audio frames. Then the collector waits until 90% of the frames in + the window are unvoiced to detrigger. + The window is padded at the front and back to provide a small + amount of silence or the beginnings/endings of speech around the + voiced frames. + Arguments: + sample_rate - The audio sample rate, in Hz. + frame_duration_ms - The frame duration in milliseconds. + padding_duration_ms - The amount to pad the window, in milliseconds. + vad - An instance of webrtcvad.Vad. + frames - a source of audio frames (sequence or generator). + Returns: A generator that yields PCM audio data. + """ + num_padding_frames = int(padding_duration_ms / frame_duration_ms) + # We use a deque for our sliding window/ring buffer. + ring_buffer = collections.deque(maxlen=num_padding_frames) + # We have two states: TRIGGERED and NOTTRIGGERED. We start in the + # NOTTRIGGERED state. 
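+ # Worked example with the values passed from main() below (FS_MS = 30 ms
+ # frames, 300 ms padding): the ring buffer holds int(300 / 30) = 10 frames,
+ # so the collector triggers only when num_voiced > 0.9 * 10 = 9, i.e. when
+ # all ten buffered frames are voiced, and detriggers again only once all
+ # ten buffered frames are unvoiced.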
+ triggered = False + + voiced_frames = [] + for frame in frames: + is_speech = vad.is_speech(frame.bytes, sample_rate) + + # sys.stdout.write('1' if is_speech else '0') + if not triggered: + ring_buffer.append((frame, is_speech)) + num_voiced = len([f for f, speech in ring_buffer if speech]) + # If we're NOTTRIGGERED and more than 90% of the frames in + # the ring buffer are voiced frames, then enter the + # TRIGGERED state. + if num_voiced > 0.9 * ring_buffer.maxlen: + triggered = True + # We want to yield all the audio we see from now until + # we are NOTTRIGGERED, but we have to start with the + # audio that's already in the ring buffer. + for f, _ in ring_buffer: + voiced_frames.append(f) + ring_buffer.clear() + else: + # We're in the TRIGGERED state, so collect the audio data + # and add it to the ring buffer. + voiced_frames.append(frame) + ring_buffer.append((frame, is_speech)) + num_unvoiced = len([f for f, speech in ring_buffer if not speech]) + # If more than 90% of the frames in the ring buffer are + # unvoiced, then enter NOTTRIGGERED and yield whatever + # audio we've collected. + if num_unvoiced > 0.9 * ring_buffer.maxlen: + triggered = False + yield [b''.join([f.bytes for f in voiced_frames]), + voiced_frames[0].timestamp, voiced_frames[-1].timestamp] + ring_buffer.clear() + voiced_frames = [] + # If we have any leftover voiced audio when we run out of input, + # yield it. + if voiced_frames: + yield [b''.join([f.bytes for f in voiced_frames]), + voiced_frames[0].timestamp, voiced_frames[-1].timestamp] + + +def main(args): + # create output folder + try: + cmd = f"mkdir -p {args.out_path}" + os.system(cmd) + except Exception: + logging.error("Can not create output folder") + exit(-1) + + # build vad object + vad = webrtcvad.Vad(int(args.agg)) + # iterating over wavs in dir + for file in tqdm(os.listdir(args.in_path)): + if file.endswith(AUDIO_SUFFIX): + audio_inpath = os.path.join(args.in_path, file) + audio_outpath = os.path.join(args.out_path, file) + audio, sample_rate = read_wave(audio_inpath) + frames = frame_generator(FS_MS, audio, sample_rate) + frames = list(frames) + segments = vad_collector(sample_rate, FS_MS, 300, vad, frames) + merge_segments = list() + timestamp_start = 0.0 + timestamp_end = 0.0 + # removing start, end, and long sequences of sils + for i, segment in enumerate(segments): + merge_segments.append(segment[0]) + if i and timestamp_start: + sil_duration = segment[1] - timestamp_end + if sil_duration > THRESHOLD: + merge_segments.append(int(THRESHOLD / SCALE)*(b'\x00')) + else: + merge_segments.append(int((sil_duration / SCALE))*(b'\x00')) + timestamp_start = segment[1] + timestamp_end = segment[2] + segment = b''.join(merge_segments) + write_wave(audio_outpath, segment, sample_rate) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Apply vad to a file of fils.') + parser.add_argument('in_path', type=str, help='Path to the input files') + parser.add_argument('out_path', type=str, + help='Path to save the processed files') + parser.add_argument('--agg', type=int, default=3, + help='The level of aggressiveness of the VAD: [0-3]') + args = parser.parse_args() + + main(args) diff --git a/fairseq/examples/speech_synthesis/utils.py b/fairseq/examples/speech_synthesis/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7b03733d2290d3834d2c68a16034198daa1e69 --- /dev/null +++ b/fairseq/examples/speech_synthesis/utils.py @@ -0,0 +1,101 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from scipy.interpolate import interp1d +import torchaudio + +from fairseq.tasks.text_to_speech import ( + batch_compute_distortion, compute_rms_dist +) + + +def batch_mel_spectral_distortion( + y1, y2, sr, normalize_type="path", mel_fn=None +): + """ + https://arxiv.org/pdf/2011.03568.pdf + + Same as Mel Cepstral Distortion, but computed on log-mel spectrograms. + """ + if mel_fn is None or mel_fn.sample_rate != sr: + mel_fn = torchaudio.transforms.MelSpectrogram( + sr, n_fft=int(0.05 * sr), win_length=int(0.05 * sr), + hop_length=int(0.0125 * sr), f_min=20, n_mels=80, + window_fn=torch.hann_window + ).to(y1[0].device) + offset = 1e-6 + return batch_compute_distortion( + y1, y2, sr, lambda y: torch.log(mel_fn(y) + offset).transpose(-1, -2), + compute_rms_dist, normalize_type + ) + + +# This code is based on +# "https://github.com/bastibe/MAPS-Scripts/blob/master/helper.py" +def _same_t_in_true_and_est(func): + def new_func(true_t, true_f, est_t, est_f): + assert type(true_t) is np.ndarray + assert type(true_f) is np.ndarray + assert type(est_t) is np.ndarray + assert type(est_f) is np.ndarray + + interpolated_f = interp1d( + est_t, est_f, bounds_error=False, kind='nearest', fill_value=0 + )(true_t) + return func(true_t, true_f, true_t, interpolated_f) + + return new_func + + +@_same_t_in_true_and_est +def gross_pitch_error(true_t, true_f, est_t, est_f): + """The relative frequency in percent of pitch estimates that are + outside a threshold around the true pitch. Only frames that are + considered pitched by both the ground truth and the estimator (if + applicable) are considered. + """ + + correct_frames = _true_voiced_frames(true_t, true_f, est_t, est_f) + gross_pitch_error_frames = _gross_pitch_error_frames( + true_t, true_f, est_t, est_f + ) + return np.sum(gross_pitch_error_frames) / np.sum(correct_frames) + + +def _gross_pitch_error_frames(true_t, true_f, est_t, est_f, eps=1e-8): + voiced_frames = _true_voiced_frames(true_t, true_f, est_t, est_f) + true_f_p_eps = [x + eps for x in true_f] + pitch_error_frames = np.abs(est_f / true_f_p_eps - 1) > 0.2 + return voiced_frames & pitch_error_frames + + +def _true_voiced_frames(true_t, true_f, est_t, est_f): + return (est_f != 0) & (true_f != 0) + + +def _voicing_decision_error_frames(true_t, true_f, est_t, est_f): + return (est_f != 0) != (true_f != 0) + + +@_same_t_in_true_and_est +def f0_frame_error(true_t, true_f, est_t, est_f): + gross_pitch_error_frames = _gross_pitch_error_frames( + true_t, true_f, est_t, est_f + ) + voicing_decision_error_frames = _voicing_decision_error_frames( + true_t, true_f, est_t, est_f + ) + return (np.sum(gross_pitch_error_frames) + + np.sum(voicing_decision_error_frames)) / (len(true_t)) + + +@_same_t_in_true_and_est +def voicing_decision_error(true_t, true_f, est_t, est_f): + voicing_decision_error_frames = _voicing_decision_error_frames( + true_t, true_f, est_t, est_f + ) + return np.sum(voicing_decision_error_frames) / (len(true_t)) diff --git a/fairseq/examples/speech_text_joint_to_text/README.md b/fairseq/examples/speech_text_joint_to_text/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c1aa11929a434a0d146dcfa05fc067d0fac1b310 --- /dev/null +++ b/fairseq/examples/speech_text_joint_to_text/README.md @@ -0,0 +1,51 @@ +# Joint Speech Text training in Fairseq +An extension of Fairseq s2t project with the speech to 
text task enhanced by the co-trained text to text mapping task. More details about Fairseq s2t can be found [here](../speech_to_text/README.md) + +## Examples +Examples of speech text joint training in fairseq +- [English-to-German MuST-C model](docs/ende-mustc.md) +- [IWSLT 2021 Multilingual Speech Translation](docs/iwslt2021.md) +- [Speech Text Joint Pre-training ](docs/pre-training.md) +## Citation +Please cite as: +``` +@inproceedings{Tang2022UnifiedSP, + title={Unified Speech-Text Pre-training for Speech Translation and Recognition}, + author={Yun Tang and Hongyu Gong and Ning Dong and Changhan Wang and Wei-Ning Hsu and Jiatao Gu and Alexei Baevski and Xian Li and Abdelrahman Mohamed and Michael Auli and Juan Miguel Pino}, + booktitle={ACL}, + year={2022} +} +@inproceedings{Tang2021IST, + title = {Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task}, + author = {Yun Tang and Juan Pino and Xian Li and Changhan Wang and Dmitriy Genzel}, + booktitle = {ACL}, + year = {2021}, +} + +@inproceedings{Tang2021FST, + title = {FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task}, + author = {Yun Tang and Hongyu Gong and Xian Li and Changhan Wang and Juan Pino and Holger Schwenk and Naman Goyal}, + booktitle = {IWSLT}, + year = {2021}, +} +@inproceedings{Tang2021AGM, + title={A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks}, + author={Yun Tang and J. Pino and Changhan Wang and Xutai Ma and Dmitriy Genzel}, + booktitle={ICASSP}, + year={2021} +} + +@inproceedings{wang2020fairseqs2t, + title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq}, + author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino}, + booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations}, + year = {2020}, +} + +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` diff --git a/fairseq/examples/speech_text_joint_to_text/__init__.py b/fairseq/examples/speech_text_joint_to_text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239d2e69f9a235095dee1ea7b3a94164a77273f5 --- /dev/null +++ b/fairseq/examples/speech_text_joint_to_text/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import tasks, criterions, models # noqa diff --git a/fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list b/fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list new file mode 100644 index 0000000000000000000000000000000000000000..02eeac4e009f77b765004272f59a1618214da18d --- /dev/null +++ b/fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list @@ -0,0 +1,49 @@ +"(Applause) NOISE +"(Laughter) VOICE +"(Laughter)" VOICE +(Applause) NOISE +(Applause). 
NOISE +(Audience) VOICE +(Audio) NOISE +(Beat) NOISE +(Beatboxing) VOICE +(Beep) NOISE +(Beeps) NOISE +(Cheering) VOICE +(Cheers) VOICE +(Claps) NOISE +(Clicking) NOISE +(Clunk) NOISE +(Coughs) NOISE +(Drums) NOISE +(Explosion) NOISE +(Gasps) VOICE +(Guitar) NOISE +(Honk) NOISE +(Laugher) VOICE +(Laughing) VOICE +(Laughs) VOICE +(Laughter) VOICE +(Laughter). VOICE +(Laughter)... VOICE +(Mumbling) VOICE +(Music) NOISE +(Noise) NOISE +(Recording) VOICE +(Ringing) NOISE +(Shouts) VOICE +(Sigh) VOICE +(Sighs) VOICE +(Silence) NOISE +(Singing) VOICE +(Sings) VOICE +(Spanish) VOICE +(Static) NOISE +(Tones) NOISE +(Trumpet) NOISE +(Video) NOISE +(Video): NOISE +(Voice-over) NOISE +(Whistle) NOISE +(Whistling) NOISE +(video): NOISE diff --git a/fairseq/examples/speech_text_joint_to_text/criterions/__init__.py b/fairseq/examples/speech_text_joint_to_text/criterions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7faae73119321af0b34fe8e26499a2ef5577291a --- /dev/null +++ b/fairseq/examples/speech_text_joint_to_text/criterions/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + criterion_name = file[: file.find(".py")] + importlib.import_module( + "examples.speech_text_joint_to_text.criterions." + criterion_name + ) diff --git a/fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py b/fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a5506a2d29436e33776ff956e066205d907f99 --- /dev/null +++ b/fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py @@ -0,0 +1,181 @@ +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import math +from dataclasses import dataclass, field + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterionConfig, +) +from fairseq.logging.meters import safe_round + +from .multi_modality_cross_entropy import SpeechTextPreTrainCrossEntCriterion + +logger = logging.getLogger(__name__) + + +@dataclass +class SpeechTextPreTrainCompoundCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + zero_infinity: bool = field( + default=False, + metadata={"help": "zero inf loss when source length <= target length"}, + ) + post_process: str = field( + default="none", + metadata={ + "help": "how to post process predictions into words. can be letter, " + "wordpiece, BPE symbols, etc. 
" + "See fairseq.data.data_utils.post_process() for full list of options" + }, + ) + + +@register_criterion( + "speech_text_pretrain_compound", dataclass=SpeechTextPreTrainCompoundCriterionConfig +) +class SpeechTextPreTrainCompoundCriterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + report_accuracy=False, + zero_infinity=False, + post_process=None, + ): + super().__init__(task) + self.xent = SpeechTextPreTrainCrossEntCriterion( + task, sentence_avg, label_smoothing, report_accuracy + ) + cfg_dict = { + "zero_infinity": zero_infinity, + "sentence_avg": sentence_avg, + "post_process": post_process, + } + cfg_ctc = CtcCriterionConfig(**cfg_dict) + self.ctc = CtcCriterion(cfg_ctc, task) + + def forward(self, model, sample, reduce=True): + mode = sample["net_input"]["mode"] + if mode == "sup_speech_ctc": # CTC + sample["net_input"][ + "src_lengths" + ] = None # get downsampled src_lengths from padding_mask + loss, sample_size, logging_output = self.ctc(model, sample, reduce) + logging_output["mode"] = SpeechTextPreTrainCompoundCriterion.mode2value( + "CTC" + ) + else: + loss, sample_size, logging_output = self.xent(model, sample, reduce) + logging_output["mode"] = SpeechTextPreTrainCompoundCriterion.mode2value( + "xent" + ) + + return loss, sample_size, logging_output + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True + + @staticmethod + def mode2value(mode): # make the logging_outputs_can_be_summed = True + if mode == "CTC": + return 907 # prime number + if mode == "xent": + return 887 # prime number + return 0 + + @staticmethod + def value2mode(value): + if value % 907 == 0: + return "CTC" + if value % 887 == 0: + return "xent" + raise ValueError("Unknow mode") + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + + def _get_mode(logging_outputs): + mds = [ + SpeechTextPreTrainCompoundCriterion.value2mode(log["mode"]) + for log in logging_outputs + ] + if sum([1 if l != mds[0] else 0 for l in mds]) > 0: + raise ValueError("mode in one mini-batch is expected to be the same!") + return mds[0] + + log_mode = _get_mode(logging_outputs) + if log_mode == "xent": + return SpeechTextPreTrainCrossEntCriterion.reduce_metrics(logging_outputs) + + # ctc loss + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ctc_ntokens", ntokens) + metrics.log_scalar("ctc_nsentences", nsentences) + if sample_size != ntokens: + metrics.log_scalar( + "ctc_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log.get("c_total", 0) for log in logging_outputs) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log.get("w_errors", 0) for log in logging_outputs) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = 
sum(log.get("wv_errors", 0) for log in logging_outputs) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log.get("w_total", 0) for log in logging_outputs) + metrics.log_scalar("_w_total", w_total) + + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: safe_round( + meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3 + ) + if meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: safe_round( + meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: safe_round( + meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + )