PyTorch · ssl-aasist · custom_code
ash56 committed (verified)
Commit a1d9110 · 1 Parent(s): 77d5bb2

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. fairseq/examples/hubert/tests/sample.base.L9.npy +3 -0
  2. fairseq/examples/hubert/tests/sample.large.L20.npy +3 -0
  3. fairseq/examples/simultaneous_translation/__pycache__/__init__.cpython-310.pyc +0 -0
  4. fairseq/examples/simultaneous_translation/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  5. fairseq/examples/simultaneous_translation/utils/__pycache__/functions.cpython-310.pyc +0 -0
  6. fairseq/examples/simultaneous_translation/utils/__pycache__/p_choose_strategy.cpython-310.pyc +0 -0
  7. fairseq/examples/speech_recognition/README.md +87 -0
  8. fairseq/examples/speech_recognition/__init__.py +1 -0
  9. fairseq/examples/speech_recognition/criterions/ASG_loss.py +170 -0
  10. fairseq/examples/speech_recognition/criterions/__init__.py +17 -0
  11. fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py +130 -0
  12. fairseq/examples/speech_recognition/data/__init__.py +11 -0
  13. fairseq/examples/speech_recognition/data/asr_dataset.py +122 -0
  14. fairseq/examples/speech_recognition/data/collaters.py +131 -0
  15. fairseq/examples/speech_recognition/data/data_utils.py +100 -0
  16. fairseq/examples/speech_recognition/data/replabels.py +70 -0
  17. fairseq/examples/speech_recognition/datasets/asr_prep_json.py +125 -0
  18. fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh +88 -0
  19. fairseq/examples/speech_recognition/infer.py +436 -0
  20. fairseq/examples/speech_recognition/kaldi/__init__.py +0 -0
  21. fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc +94 -0
  22. fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml +8 -0
  23. fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py +244 -0
  24. fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py +698 -0
  25. fairseq/examples/speech_recognition/models/__init__.py +8 -0
  26. fairseq/examples/speech_recognition/models/vggtransformer.py +1020 -0
  27. fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py +177 -0
  28. fairseq/examples/speech_recognition/new/README.md +43 -0
  29. fairseq/examples/speech_recognition/new/__init__.py +0 -0
  30. fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml +29 -0
  31. fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml +29 -0
  32. fairseq/examples/speech_recognition/new/conf/infer.yaml +27 -0
  33. fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml +28 -0
  34. fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml +27 -0
  35. fairseq/examples/speech_recognition/new/decoders/__init__.py +0 -0
  36. fairseq/examples/speech_recognition/new/decoders/base_decoder.py +62 -0
  37. fairseq/examples/speech_recognition/new/decoders/decoder.py +32 -0
  38. fairseq/examples/speech_recognition/new/decoders/decoder_config.py +70 -0
  39. fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py +433 -0
  40. fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py +24 -0
  41. fairseq/examples/speech_recognition/new/infer.py +502 -0
  42. fairseq/examples/speech_recognition/tasks/__init__.py +8 -0
  43. fairseq/examples/speech_recognition/tasks/speech_recognition.py +157 -0
  44. fairseq/examples/speech_recognition/utils/wer_utils.py +381 -0
  45. fairseq/examples/speech_recognition/w2l_decoder.py +486 -0
  46. fairseq/examples/speech_synthesis/README.md +38 -0
  47. fairseq/examples/speech_synthesis/__init__.py +4 -0
  48. fairseq/examples/speech_synthesis/data_utils.py +344 -0
  49. fairseq/examples/speech_synthesis/docs/common_voice_example.md +67 -0
  50. fairseq/examples/speech_synthesis/docs/ljspeech_example.md +137 -0
fairseq/examples/hubert/tests/sample.base.L9.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b44dc3a0519f6adb670a5da7410c1405f4fdb0e4866e5fc4983b136480212f34
+ size 1831104
fairseq/examples/hubert/tests/sample.large.L20.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c82839134cc2340355eb49b41e7e3517e8e9dfdfaa6fc28f0464cc8ae9569ee
+ size 2441408
fairseq/examples/simultaneous_translation/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (258 Bytes).
 
fairseq/examples/simultaneous_translation/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (509 Bytes).
 
fairseq/examples/simultaneous_translation/utils/__pycache__/functions.cpython-310.pyc ADDED
Binary file (3.29 kB).
 
fairseq/examples/simultaneous_translation/utils/__pycache__/p_choose_strategy.cpython-310.pyc ADDED
Binary file (1.79 kB).
 
fairseq/examples/speech_recognition/README.md ADDED
@@ -0,0 +1,87 @@
+ ### 2021 Update: We are merging this example into the [S2T framework](../speech_to_text), which supports more generic speech-to-text tasks (e.g. speech translation) and more flexible data processing pipelines. Please stay tuned.
+
+ # Speech Recognition
+ `examples/speech_recognition` implements the ASR task in fairseq, along with the features, datasets, models and loss functions needed to train and run inference with the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
+
+
+ ## Additional dependencies
+ On top of the main fairseq dependencies there are a couple of additional requirements.
+
+ 1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
+ 2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from the sctk package [here](http://www.openslr.org/4/). Training and inference do not require the Sclite dependency.
+ 3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create a dataset with word-piece targets.
+
+ ## Preparing LibriSpeech data
+ ```
+ ./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
+ ```
+
+ ## Training on LibriSpeech data
+ ```
+ python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
+ ```
+
+ ## Inference for LibriSpeech
+ `$SET` can be `test_clean` or `test_other`.
+ Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`.
+ ```
+ python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
+ ```
+
+ ## Scoring with Sclite
+ ```
+ sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
+ ```
+ The `Sum/Avg` row of the first table in the report contains the WER.
+
+ ## Using flashlight (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) components
+ [flashlight](https://github.com/facebookresearch/flashlight) is now integrated with fairseq. Currently this includes:
+
+ * AutoSegmentationCriterion (ASG)
+ * flashlight-style Conv/GLU model
+ * flashlight's beam search decoder
+
+ To use these, follow the instructions on [this page](https://github.com/flashlight/flashlight/tree/e16682fa32df30cbf675c8fe010f929c61e3b833/bindings/python) to install the Python bindings. **Flashlight v0.3.2** must be used to install the bindings. Running:
+ ```
+ git clone --branch v0.3.2 https://github.com/flashlight/flashlight
+ ```
+ will properly clone and check out this version.
+
+ ## Training on LibriSpeech data (flashlight style, Conv/GLU + ASG loss)
+ Training command:
+ ```
+ python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
+ ```
+
+ Note that the ASG loss currently does not work well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
+
+ ## Inference for LibriSpeech (flashlight decoder, n-gram LM)
+ Inference command:
+ ```
+ python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
+ ```
+
+ `$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a flashlight-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
+ ```
+ doorbell D O 1 R B E L 1 ▁
+ ```
+ For CTC inference with word-pieces, repetition labels are not used and the lexicon should contain the most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this; a sketch follows this README):
+ ```
+ doorbell ▁DOOR BE LL
+ doorbell ▁DOOR B E LL
+ doorbell ▁DO OR BE LL
+ doorbell ▁DOOR B EL L
+ doorbell ▁DOOR BE L L
+ doorbell ▁DO OR B E LL
+ doorbell ▁DOOR B E L L
+ doorbell ▁DO OR B EL L
+ doorbell ▁DO O R BE LL
+ doorbell ▁DO OR BE L L
+ ```
+ Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
+
+ ## Inference for LibriSpeech (flashlight decoder, Viterbi only)
+ Inference command:
+ ```
+ python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
+ ```
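The word-piece lexicon described above can be generated in a few lines. A minimal sketch, assuming the `spm.model` written by `prepare-librispeech.sh` and a hypothetical `words.txt` word list (one word per line); the output format matches the `doorbell` example:
```
# Editor's sketch, not part of this commit: build a flashlight-style CTC
# lexicon from n-best sentencepiece segmentations. `spm.model`, `words.txt`
# and `lexicon.txt` are placeholder paths.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("spm.model")

with open("words.txt") as f, open("lexicon.txt", "w") as out:
    for word in (line.strip() for line in f):
        if not word:
            continue
        # up to 10 alternative word-piece spellings per word; keep the word's
        # case aligned with the n-gram LM and the pieces with dict.txt
        for pieces in sp.NBestEncodeAsPieces(word, 10):
            out.write("{} {}\n".format(word, " ".join(pieces)))
```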
fairseq/examples/speech_recognition/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import criterions, models, tasks  # noqa
fairseq/examples/speech_recognition/criterions/ASG_loss.py ADDED
@@ -0,0 +1,170 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import torch
9
+ from examples.speech_recognition.data.replabels import pack_replabels
10
+ from fairseq import utils
11
+ from fairseq.criterions import FairseqCriterion, register_criterion
12
+
13
+
14
+ @register_criterion("asg_loss")
15
+ class ASGCriterion(FairseqCriterion):
16
+ @staticmethod
17
+ def add_args(parser):
18
+ group = parser.add_argument_group("ASG Loss")
19
+ group.add_argument(
20
+ "--asg-transitions-init",
21
+ help="initial diagonal value of transition matrix",
22
+ type=float,
23
+ default=0.0,
24
+ )
25
+ group.add_argument(
26
+ "--max-replabel", help="maximum # of replabels", type=int, default=2
27
+ )
28
+ group.add_argument(
29
+ "--linseg-updates",
30
+ help="# of training updates to use LinSeg initialization",
31
+ type=int,
32
+ default=0,
33
+ )
34
+ group.add_argument(
35
+ "--hide-linseg-messages",
36
+ help="hide messages about LinSeg initialization",
37
+ action="store_true",
38
+ )
39
+
40
+ def __init__(
41
+ self,
42
+ task,
43
+ silence_token,
44
+ asg_transitions_init,
45
+ max_replabel,
46
+ linseg_updates,
47
+ hide_linseg_messages,
48
+ ):
49
+ from flashlight.lib.sequence.criterion import ASGLoss, CriterionScaleMode
50
+
51
+ super().__init__(task)
52
+ self.tgt_dict = task.target_dictionary
53
+ self.eos = self.tgt_dict.eos()
54
+ self.silence = (
55
+ self.tgt_dict.index(silence_token)
56
+ if silence_token in self.tgt_dict
57
+ else None
58
+ )
59
+ self.max_replabel = max_replabel
60
+
61
+ num_labels = len(self.tgt_dict)
62
+ self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
63
+ self.asg.trans = torch.nn.Parameter(
64
+ asg_transitions_init * torch.eye(num_labels), requires_grad=True
65
+ )
66
+
67
+ self.linseg_progress = torch.nn.Parameter(
68
+ torch.tensor([0], dtype=torch.int), requires_grad=False
69
+ )
70
+ self.linseg_maximum = linseg_updates
71
+ self.linseg_message_state = "none" if hide_linseg_messages else "start"
72
+
73
+ @classmethod
74
+ def build_criterion(cls, args, task):
75
+ return cls(
76
+ task,
77
+ args.silence_token,
78
+ args.asg_transitions_init,
79
+ args.max_replabel,
80
+ args.linseg_updates,
81
+ args.hide_linseg_messages,
82
+ )
83
+
84
+ def linseg_step(self):
85
+ if not self.training:
86
+ return False
87
+ if self.linseg_progress.item() < self.linseg_maximum:
88
+ if self.linseg_message_state == "start":
89
+ print("| using LinSeg to initialize ASG")
90
+ self.linseg_message_state = "finish"
91
+ self.linseg_progress.add_(1)
92
+ return True
93
+ elif self.linseg_message_state == "finish":
94
+ print("| finished LinSeg initialization")
95
+ self.linseg_message_state = "none"
96
+ return False
97
+
98
+ def replace_eos_with_silence(self, tgt):
99
+ if tgt[-1] != self.eos:
100
+ return tgt
101
+ elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
102
+ return tgt[:-1]
103
+ else:
104
+ return tgt[:-1] + [self.silence]
105
+
106
+ def forward(self, model, sample, reduce=True):
107
+ """Compute the loss for the given sample.
108
+
109
+ Returns a tuple with three elements:
110
+ 1) the loss
111
+ 2) the sample size, which is used as the denominator for the gradient
112
+ 3) logging outputs to display while training
113
+ """
114
+
115
+ net_output = model(**sample["net_input"])
116
+ emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
117
+ B = emissions.size(0)
118
+ T = emissions.size(1)
119
+ device = emissions.device
120
+
121
+ target = torch.IntTensor(B, T)
122
+ target_size = torch.IntTensor(B)
123
+ using_linseg = self.linseg_step()
124
+
125
+ for b in range(B):
126
+ initial_target_size = sample["target_lengths"][b].item()
127
+ if initial_target_size == 0:
128
+ raise ValueError("target size cannot be zero")
129
+
130
+ tgt = sample["target"][b, :initial_target_size].tolist()
131
+ tgt = self.replace_eos_with_silence(tgt)
132
+ tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
133
+ tgt = tgt[:T]
134
+
135
+ if using_linseg:
136
+ tgt = [tgt[t * len(tgt) // T] for t in range(T)]
137
+
138
+ target[b][: len(tgt)] = torch.IntTensor(tgt)
139
+ target_size[b] = len(tgt)
140
+
141
+ loss = self.asg.forward(emissions, target.to(device), target_size.to(device))
142
+
143
+ if reduce:
144
+ loss = torch.sum(loss)
145
+
146
+ sample_size = (
147
+ sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
148
+ )
149
+ logging_output = {
150
+ "loss": utils.item(loss.data) if reduce else loss.data,
151
+ "ntokens": sample["ntokens"],
152
+ "nsentences": sample["target"].size(0),
153
+ "sample_size": sample_size,
154
+ }
155
+ return loss, sample_size, logging_output
156
+
157
+ @staticmethod
158
+ def aggregate_logging_outputs(logging_outputs):
159
+ """Aggregate logging outputs from data parallel training."""
160
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
161
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
162
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
163
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
164
+ agg_output = {
165
+ "loss": loss_sum / nsentences,
166
+ "ntokens": ntokens,
167
+ "nsentences": nsentences,
168
+ "sample_size": sample_size,
169
+ }
170
+ return agg_output
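A toy illustration (values made up) of the LinSeg initialization used in `ASGCriterion.forward` above: for the first `--linseg-updates` training steps, the packed target is stretched linearly across the `T` emission frames before being handed to the ASG loss:
```
# Editor's sketch: the same expression as in ASGCriterion.forward,
# applied to a made-up packed target.
tgt = [7, 3, 9]    # packed target labels
T = 8              # number of emission frames
linseg_tgt = [tgt[t * len(tgt) // T] for t in range(T)]
print(linseg_tgt)  # [7, 7, 7, 3, 3, 3, 9, 9]
```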
fairseq/examples/speech_recognition/criterions/__init__.py ADDED
@@ -0,0 +1,17 @@
+ import importlib
+ import os
+
+
+ # ASG loss requires flashlight bindings
+ files_to_skip = set()
+ try:
+     import flashlight.lib.sequence.criterion
+ except ImportError:
+     files_to_skip.add("ASG_loss.py")
+
+ for file in sorted(os.listdir(os.path.dirname(__file__))):
+     if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip:
+         criterion_name = file[: file.find(".py")]
+         importlib.import_module(
+             "examples.speech_recognition.criterions." + criterion_name
+         )
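The loop above imports every non-underscore `.py` file in this directory so that its `@register_criterion` decorator runs. A hypothetical new criterion (the file name and class below are made up, not part of this commit) would therefore only need to be dropped into the same directory to become selectable with `--criterion my_criterion`:
```
# Editor's sketch of a hypothetical criterions/my_criterion.py
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("my_criterion")
class MyCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        # must return (loss, sample_size, logging_output)
        ...
```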
fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import absolute_import, division, print_function, unicode_literals
7
+
8
+ import logging
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from fairseq import utils
14
+ from fairseq.criterions import FairseqCriterion, register_criterion
15
+
16
+
17
+ @register_criterion("cross_entropy_acc")
18
+ class CrossEntropyWithAccCriterion(FairseqCriterion):
19
+ def __init__(self, task, sentence_avg):
20
+ super().__init__(task)
21
+ self.sentence_avg = sentence_avg
22
+
23
+ def compute_loss(self, model, net_output, target, reduction, log_probs):
24
+ # N, T -> N * T
25
+ target = target.view(-1)
26
+ lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
27
+ if not hasattr(lprobs, "batch_first"):
28
+ logging.warning(
29
+ "ERROR: we need to know whether "
30
+ "batch first for the net output; "
31
+ "you need to set batch_first attribute for the return value of "
32
+ "model.get_normalized_probs. Now, we assume this is true, but "
33
+ "in the future, we will raise exception instead. "
34
+ )
35
+ batch_first = getattr(lprobs, "batch_first", True)
36
+ if not batch_first:
37
+ lprobs = lprobs.transpose(0, 1)
38
+
39
+ # N, T, D -> N * T, D
40
+ lprobs = lprobs.view(-1, lprobs.size(-1))
41
+ loss = F.nll_loss(
42
+ lprobs, target, ignore_index=self.padding_idx, reduction=reduction
43
+ )
44
+ return lprobs, loss
45
+
46
+ def get_logging_output(self, sample, target, lprobs, loss):
47
+ target = target.view(-1)
48
+ mask = target != self.padding_idx
49
+ correct = torch.sum(
50
+ lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
51
+ )
52
+ total = torch.sum(mask)
53
+ sample_size = (
54
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
55
+ )
56
+
57
+ logging_output = {
58
+ "loss": utils.item(loss.data), # * sample['ntokens'],
59
+ "ntokens": sample["ntokens"],
60
+ "nsentences": sample["target"].size(0),
61
+ "sample_size": sample_size,
62
+ "correct": utils.item(correct.data),
63
+ "total": utils.item(total.data),
64
+ "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
65
+ }
66
+
67
+ return sample_size, logging_output
68
+
69
+ def forward(self, model, sample, reduction="sum", log_probs=True):
70
+ """Computes the cross entropy with accuracy metric for the given sample.
71
+
72
+ This is similar to CrossEntropyCriterion in fairseq, but also
73
+ computes accuracy metrics as part of logging
74
+
75
+ Args:
76
+ logprobs (Torch.tensor) of shape N, T, D i.e.
77
+ batchsize, timesteps, dimensions
78
+ targets (Torch.tensor) of shape N, T i.e batchsize, timesteps
79
+
80
+ Returns:
81
+ tuple: With three elements:
82
+ 1) the loss
83
+ 2) the sample size, which is used as the denominator for the gradient
84
+ 3) logging outputs to display while training
85
+
86
+ TODO:
87
+ * Currently this Criterion will only work with LSTMEncoderModels or
88
+ FairseqModels which have decoder, or Models which return TorchTensor
89
+ as net_output.
90
+ We need to make a change to support all FairseqEncoder models.
91
+ """
92
+ net_output = model(**sample["net_input"])
93
+ target = model.get_targets(sample, net_output)
94
+ lprobs, loss = self.compute_loss(
95
+ model, net_output, target, reduction, log_probs
96
+ )
97
+ sample_size, logging_output = self.get_logging_output(
98
+ sample, target, lprobs, loss
99
+ )
100
+ return loss, sample_size, logging_output
101
+
102
+ @staticmethod
103
+ def aggregate_logging_outputs(logging_outputs):
104
+ """Aggregate logging outputs from data parallel training."""
105
+ correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
106
+ total_sum = sum(log.get("total", 0) for log in logging_outputs)
107
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
108
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
109
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
110
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
111
+ nframes = sum(log.get("nframes", 0) for log in logging_outputs)
112
+ agg_output = {
113
+ "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
114
+ # if args.sentence_avg, then sample_size is nsentences, then loss
115
+ # is per-sentence loss; else sample_size is ntokens, the loss
116
+ # becomes per-output token loss
117
+ "ntokens": ntokens,
118
+ "nsentences": nsentences,
119
+ "nframes": nframes,
120
+ "sample_size": sample_size,
121
+ "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
122
+ "correct": correct_sum,
123
+ "total": total_sum,
124
+ # total is the number of validate tokens
125
+ }
126
+ if sample_size != ntokens:
127
+ agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
128
+ # loss: per output token loss
129
+ # nll_loss: per sentence loss
130
+ return agg_output
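A standalone sketch (made-up shapes, plain PyTorch) of the masked token accuracy computed in `get_logging_output` above: padding positions are excluded before comparing argmax predictions against the flattened targets:
```
# Editor's sketch: masked accuracy over flattened (N*T,) targets.
import torch

padding_idx = 1
lprobs = torch.randn(12, 30)            # (N*T, vocab) log-probabilities
target = torch.randint(2, 30, (12,))    # (N*T,) flattened targets
target[-3:] = padding_idx               # pretend the last 3 positions are padding

mask = target != padding_idx
correct = (lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)).sum()
total = mask.sum()
print("acc = {:.1f}%".format(100.0 * correct.item() / total.item()))
```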
fairseq/examples/speech_recognition/data/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .asr_dataset import AsrDataset
+
+
+ __all__ = [
+     "AsrDataset",
+ ]
fairseq/examples/speech_recognition/data/asr_dataset.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+
8
+ import numpy as np
9
+ from fairseq.data import FairseqDataset
10
+
11
+ from . import data_utils
12
+ from .collaters import Seq2SeqCollater
13
+
14
+
15
+ class AsrDataset(FairseqDataset):
16
+ """
17
+ A dataset representing speech and corresponding transcription.
18
+
19
+ Args:
20
+ aud_paths: (List[str]): A list of str with paths to audio files.
21
+ aud_durations_ms (List[int]): A list of int containing the durations of
22
+ audio files.
23
+ tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
24
+ of target transcriptions.
25
+ tgt_dict (~fairseq.data.Dictionary): target vocabulary.
26
+ ids (List[str]): A list of utterance IDs.
27
+ speakers (List[str]): A list of speakers corresponding to utterances.
28
+ num_mel_bins (int): Number of triangular mel-frequency bins (default: 80)
29
+ frame_length (float): Frame length in milliseconds (default: 25.0)
30
+ frame_shift (float): Frame shift in milliseconds (default: 10.0)
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ aud_paths,
36
+ aud_durations_ms,
37
+ tgt,
38
+ tgt_dict,
39
+ ids,
40
+ speakers,
41
+ num_mel_bins=80,
42
+ frame_length=25.0,
43
+ frame_shift=10.0,
44
+ ):
45
+ assert frame_length > 0
46
+ assert frame_shift > 0
47
+ assert all(x > frame_length for x in aud_durations_ms)
48
+ self.frame_sizes = [
49
+ int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
50
+ ]
51
+
52
+ assert len(aud_paths) > 0
53
+ assert len(aud_paths) == len(aud_durations_ms)
54
+ assert len(aud_paths) == len(tgt)
55
+ assert len(aud_paths) == len(ids)
56
+ assert len(aud_paths) == len(speakers)
57
+ self.aud_paths = aud_paths
58
+ self.tgt_dict = tgt_dict
59
+ self.tgt = tgt
60
+ self.ids = ids
61
+ self.speakers = speakers
62
+ self.num_mel_bins = num_mel_bins
63
+ self.frame_length = frame_length
64
+ self.frame_shift = frame_shift
65
+
66
+ self.s2s_collater = Seq2SeqCollater(
67
+ 0,
68
+ 1,
69
+ pad_index=self.tgt_dict.pad(),
70
+ eos_index=self.tgt_dict.eos(),
71
+ move_eos_to_beginning=True,
72
+ )
73
+
74
+ def __getitem__(self, index):
75
+ import torchaudio
76
+ import torchaudio.compliance.kaldi as kaldi
77
+
78
+ tgt_item = self.tgt[index] if self.tgt is not None else None
79
+
80
+ path = self.aud_paths[index]
81
+ if not os.path.exists(path):
82
+ raise FileNotFoundError("Audio file not found: {}".format(path))
83
+ sound, sample_rate = torchaudio.load_wav(path)
84
+ output = kaldi.fbank(
85
+ sound,
86
+ num_mel_bins=self.num_mel_bins,
87
+ frame_length=self.frame_length,
88
+ frame_shift=self.frame_shift,
89
+ )
90
+ output_cmvn = data_utils.apply_mv_norm(output)
91
+
92
+ return {"id": index, "data": [output_cmvn.detach(), tgt_item]}
93
+
94
+ def __len__(self):
95
+ return len(self.aud_paths)
96
+
97
+ def collater(self, samples):
98
+ """Merge a list of samples to form a mini-batch.
99
+
100
+ Args:
101
+ samples (List[int]): sample indices to collate
102
+
103
+ Returns:
104
+ dict: a mini-batch suitable for forwarding with a Model
105
+ """
106
+ return self.s2s_collater.collate(samples)
107
+
108
+ def num_tokens(self, index):
109
+ return self.frame_sizes[index]
110
+
111
+ def size(self, index):
112
+ """Return an example's size as a float or tuple. This value is used when
113
+ filtering a dataset with ``--max-positions``."""
114
+ return (
115
+ self.frame_sizes[index],
116
+ len(self.tgt[index]) if self.tgt is not None else 0,
117
+ )
118
+
119
+ def ordered_indices(self):
120
+ """Return an ordered list of indices. Batches will be constructed based
121
+ on this order."""
122
+ return np.arange(len(self))
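A minimal sketch of what `AsrDataset.__getitem__` does for a single utterance. `sample.wav` is a placeholder path and the fairseq examples are assumed to be importable; newer torchaudio releases drop `load_wav`, so `torchaudio.load` plus a rescale to the 16-bit range is used here as an approximation:
```
# Editor's sketch: 80-dim fbank features + per-utterance mean/variance norm.
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from examples.speech_recognition.data import data_utils

waveform, sample_rate = torchaudio.load("sample.wav")  # placeholder path
waveform = waveform * (1 << 15)  # roughly matches the un-normalized scale of load_wav
feats = kaldi.fbank(waveform, num_mel_bins=80, frame_length=25.0, frame_shift=10.0)
feats = data_utils.apply_mv_norm(feats)
print(feats.shape)  # (num_frames, 80)
```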
fairseq/examples/speech_recognition/data/collaters.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ This module contains collection of classes which implement
7
+ collate functionalities for various tasks.
8
+
9
+ Collaters should know what data to expect for each sample
10
+ and they should pack / collate them into batches
11
+ """
12
+
13
+
14
+ from __future__ import absolute_import, division, print_function, unicode_literals
15
+
16
+ import numpy as np
17
+ import torch
18
+ from fairseq.data import data_utils as fairseq_data_utils
19
+
20
+
21
+ class Seq2SeqCollater(object):
22
+ """
23
+ Implements collate function mainly for seq2seq tasks
24
+ This expects each sample to contain feature (src_tokens) and
25
+ targets.
26
+ This collator is also used for aligned training task.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ feature_index=0,
32
+ label_index=1,
33
+ pad_index=1,
34
+ eos_index=2,
35
+ move_eos_to_beginning=True,
36
+ ):
37
+ self.feature_index = feature_index
38
+ self.label_index = label_index
39
+ self.pad_index = pad_index
40
+ self.eos_index = eos_index
41
+ self.move_eos_to_beginning = move_eos_to_beginning
42
+
43
+ def _collate_frames(self, frames):
44
+ """Convert a list of 2d frames into a padded 3d tensor
45
+ Args:
46
+ frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is
47
+ length of i-th frame and f_dim is static dimension of features
48
+ Returns:
49
+ 3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
50
+ """
51
+ len_max = max(frame.size(0) for frame in frames)
52
+ f_dim = frames[0].size(1)
53
+ res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)
54
+
55
+ for i, v in enumerate(frames):
56
+ res[i, : v.size(0)] = v
57
+
58
+ return res
59
+
60
+ def collate(self, samples):
61
+ """
62
+ utility function to collate samples into batch for speech recognition.
63
+ """
64
+ if len(samples) == 0:
65
+ return {}
66
+
67
+ # parse samples into torch tensors
68
+ parsed_samples = []
69
+ for s in samples:
70
+ # skip invalid samples
71
+ if s["data"][self.feature_index] is None:
72
+ continue
73
+ source = s["data"][self.feature_index]
74
+ if isinstance(source, (np.ndarray, np.generic)):
75
+ source = torch.from_numpy(source)
76
+ target = s["data"][self.label_index]
77
+ if isinstance(target, (np.ndarray, np.generic)):
78
+ target = torch.from_numpy(target).long()
79
+ elif isinstance(target, list):
80
+ target = torch.LongTensor(target)
81
+
82
+ parsed_sample = {"id": s["id"], "source": source, "target": target}
83
+ parsed_samples.append(parsed_sample)
84
+ samples = parsed_samples
85
+
86
+ id = torch.LongTensor([s["id"] for s in samples])
87
+ frames = self._collate_frames([s["source"] for s in samples])
88
+ # sort samples by descending number of frames
89
+ frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
90
+ frames_lengths, sort_order = frames_lengths.sort(descending=True)
91
+ id = id.index_select(0, sort_order)
92
+ frames = frames.index_select(0, sort_order)
93
+
94
+ target = None
95
+ target_lengths = None
96
+ prev_output_tokens = None
97
+ if samples[0].get("target", None) is not None:
98
+ ntokens = sum(len(s["target"]) for s in samples)
99
+ target = fairseq_data_utils.collate_tokens(
100
+ [s["target"] for s in samples],
101
+ self.pad_index,
102
+ self.eos_index,
103
+ left_pad=False,
104
+ move_eos_to_beginning=False,
105
+ )
106
+ target = target.index_select(0, sort_order)
107
+ target_lengths = torch.LongTensor(
108
+ [s["target"].size(0) for s in samples]
109
+ ).index_select(0, sort_order)
110
+ prev_output_tokens = fairseq_data_utils.collate_tokens(
111
+ [s["target"] for s in samples],
112
+ self.pad_index,
113
+ self.eos_index,
114
+ left_pad=False,
115
+ move_eos_to_beginning=self.move_eos_to_beginning,
116
+ )
117
+ prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
118
+ else:
119
+ ntokens = sum(len(s["source"]) for s in samples)
120
+
121
+ batch = {
122
+ "id": id,
123
+ "ntokens": ntokens,
124
+ "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
125
+ "target": target,
126
+ "target_lengths": target_lengths,
127
+ "nsentences": len(samples),
128
+ }
129
+ if prev_output_tokens is not None:
130
+ batch["net_input"]["prev_output_tokens"] = prev_output_tokens
131
+ return batch
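A usage sketch for `Seq2SeqCollater` with toy tensors (assumes fairseq is importable): two variable-length feature matrices plus integer targets ending in the EOS index are padded, sorted by length, and packed into the batch layout the speech_recognition task expects:
```
# Editor's sketch: collate two made-up samples.
import torch

from examples.speech_recognition.data.collaters import Seq2SeqCollater

collater = Seq2SeqCollater(feature_index=0, label_index=1, pad_index=1, eos_index=2)
samples = [
    {"id": 0, "data": [torch.randn(50, 80), [4, 5, 6, 2]]},
    {"id": 1, "data": [torch.randn(35, 80), [7, 8, 2]]},
]
batch = collater.collate(samples)
print(batch["net_input"]["src_tokens"].shape)   # torch.Size([2, 50, 80]), padded, sorted by length
print(batch["target"].shape, batch["ntokens"])  # torch.Size([2, 4]) 7
```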
fairseq/examples/speech_recognition/data/data_utils.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ def calc_mean_invstddev(feature):
10
+ if len(feature.size()) != 2:
11
+ raise ValueError("We expect the input feature to be 2-D tensor")
12
+ mean = feature.mean(0)
13
+ var = feature.var(0)
14
+ # avoid division by ~zero
15
+ eps = 1e-8
16
+ if (var < eps).any():
17
+ return mean, 1.0 / (torch.sqrt(var) + eps)
18
+ return mean, 1.0 / torch.sqrt(var)
19
+
20
+
21
+ def apply_mv_norm(features):
22
+ # If there is less than 2 spectrograms, the variance cannot be computed (is NaN)
23
+ # and normalization is not possible, so return the item as it is
24
+ if features.size(0) < 2:
25
+ return features
26
+ mean, invstddev = calc_mean_invstddev(features)
27
+ res = (features - mean) * invstddev
28
+ return res
29
+
30
+
31
+ def lengths_to_encoder_padding_mask(lengths, batch_first=False):
32
+ """
33
+ convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor
34
+
35
+ Args:
36
+ lengths: a (B, )-shaped tensor
37
+
38
+ Return:
39
+ max_length: maximum length of B sequences
40
+ encoder_padding_mask: a (max_length, B) binary mask, where
41
+ [t, b] = 0 for t < lengths[b] and 1 otherwise
42
+
43
+ TODO:
44
+ kernelize this function if benchmarking shows this function is slow
45
+ """
46
+ max_lengths = torch.max(lengths).item()
47
+ bsz = lengths.size(0)
48
+ encoder_padding_mask = torch.arange(
49
+ max_lengths
50
+ ).to( # a (T, ) tensor with [0, ..., T-1]
51
+ lengths.device
52
+ ).view( # move to the right device
53
+ 1, max_lengths
54
+ ).expand( # reshape to (1, T)-shaped tensor
55
+ bsz, -1
56
+ ) >= lengths.view( # expand to (B, T)-shaped tensor
57
+ bsz, 1
58
+ ).expand(
59
+ -1, max_lengths
60
+ )
61
+ if not batch_first:
62
+ return encoder_padding_mask.t(), max_lengths
63
+ else:
64
+ return encoder_padding_mask, max_lengths
65
+
66
+
67
+ def encoder_padding_mask_to_lengths(
68
+ encoder_padding_mask, max_lengths, batch_size, device
69
+ ):
70
+ """
71
+ convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor
72
+
73
+ Conventionally, encoder output contains a encoder_padding_mask, which is
74
+ a 2-D mask in a shape (T, B), whose (t, b) element indicate whether
75
+ encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we
76
+ need to convert this mask tensor to a 1-D tensor in shape (B, ), where
77
+ [b] denotes the valid length of b-th sequence
78
+
79
+ Args:
80
+ encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None,
81
+ indicating all are valid
82
+ Return:
83
+ seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the
84
+ number of valid elements of b-th sequence
85
+
86
+ max_lengths: maximum length of all sequence, if encoder_padding_mask is
87
+ not None, max_lengths must equal to encoder_padding_mask.size(0)
88
+
89
+ batch_size: batch size; if encoder_padding_mask is
90
+ not None, max_lengths must equal to encoder_padding_mask.size(1)
91
+
92
+ device: which device to put the result on
93
+ """
94
+ if encoder_padding_mask is None:
95
+ return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device)
96
+
97
+ assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match"
98
+ assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match"
99
+
100
+ return max_lengths - torch.sum(encoder_padding_mask, dim=0)
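A quick check of `lengths_to_encoder_padding_mask` above (assumes the module is importable as below): positions at or beyond each sequence's length are marked with 1, everything else with 0:
```
# Editor's sketch: padding mask for lengths [4, 2].
import torch

from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask

mask, max_len = lengths_to_encoder_padding_mask(torch.tensor([4, 2]), batch_first=True)
print(max_len)     # 4
print(mask.int())  # tensor([[0, 0, 0, 0],
                   #         [0, 0, 1, 1]])
```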
fairseq/examples/speech_recognition/data/replabels.py ADDED
@@ -0,0 +1,70 @@
+ #!/usr/bin/env python3
+
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Replabel transforms for use with flashlight's ASG criterion.
+ """
+
+
+ def replabel_symbol(i):
+     """
+     Replabel symbols used in flashlight, currently just "1", "2", ...
+     This prevents training with numeral tokens, so this might change in the future
+     """
+     return str(i)
+
+
+ def pack_replabels(tokens, dictionary, max_reps):
+     """
+     Pack a token sequence so that repeated symbols are replaced by replabels
+     """
+     if len(tokens) == 0 or max_reps <= 0:
+         return tokens
+
+     replabel_value_to_idx = [0] * (max_reps + 1)
+     for i in range(1, max_reps + 1):
+         replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i))
+
+     result = []
+     prev_token = -1
+     num_reps = 0
+     for token in tokens:
+         if token == prev_token and num_reps < max_reps:
+             num_reps += 1
+         else:
+             if num_reps > 0:
+                 result.append(replabel_value_to_idx[num_reps])
+                 num_reps = 0
+             result.append(token)
+             prev_token = token
+     if num_reps > 0:
+         result.append(replabel_value_to_idx[num_reps])
+     return result
+
+
+ def unpack_replabels(tokens, dictionary, max_reps):
+     """
+     Unpack a token sequence so that replabels are replaced by repeated symbols
+     """
+     if len(tokens) == 0 or max_reps <= 0:
+         return tokens
+
+     replabel_idx_to_value = {}
+     for i in range(1, max_reps + 1):
+         replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i
+
+     result = []
+     prev_token = -1
+     for token in tokens:
+         try:
+             for _ in range(replabel_idx_to_value[token]):
+                 result.append(prev_token)
+             prev_token = -1
+         except KeyError:
+             result.append(token)
+             prev_token = token
+     return result
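A round-trip sketch for `pack_replabels` / `unpack_replabels` using a stub dictionary (the `StubDict` class and all token ids below are made up for illustration):
```
# Editor's sketch: repeated tokens are replaced by replabel indices and
# restored on unpacking.
from examples.speech_recognition.data.replabels import pack_replabels, unpack_replabels


class StubDict:
    def __init__(self, symbols):
        self._idx = {s: i for i, s in enumerate(symbols)}

    def index(self, symbol):
        return self._idx[symbol]


# ids 0..3 are letters, "1" and "2" are the replabel symbols
d = StubDict(["a", "b", "c", "d", "1", "2"])
tokens = [0, 0, 0, 2, 3, 3]                   # "aaacdd"
packed = pack_replabels(tokens, d, max_reps=2)
print(packed)                                 # [0, 5, 2, 3, 4] -> a "2" c d "1"
assert unpack_replabels(packed, d, max_reps=2) == tokens
```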
fairseq/examples/speech_recognition/datasets/asr_prep_json.py ADDED
@@ -0,0 +1,125 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from __future__ import absolute_import, division, print_function, unicode_literals
8
+
9
+ import argparse
10
+ import concurrent.futures
11
+ import json
12
+ import multiprocessing
13
+ import os
14
+ from collections import namedtuple
15
+ from itertools import chain
16
+
17
+ import sentencepiece as spm
18
+ from fairseq.data import Dictionary
19
+
20
+
21
+ MILLISECONDS_TO_SECONDS = 0.001
22
+
23
+
24
+ def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
25
+ import torchaudio
26
+
27
+ input = {}
28
+ output = {}
29
+ si, ei = torchaudio.info(aud_path)
30
+ input["length_ms"] = int(
31
+ si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
32
+ )
33
+ input["path"] = aud_path
34
+
35
+ token = " ".join(sp.EncodeAsPieces(lable))
36
+ ids = tgt_dict.encode_line(token, append_eos=False)
37
+ output["text"] = lable
38
+ output["token"] = token
39
+ output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
40
+ return {utt_id: {"input": input, "output": output}}
41
+
42
+
43
+ def main():
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument(
46
+ "--audio-dirs",
47
+ nargs="+",
48
+ default=["-"],
49
+ required=True,
50
+ help="input directories with audio files",
51
+ )
52
+ parser.add_argument(
53
+ "--labels",
54
+ required=True,
55
+ help="aggregated input labels with format <ID LABEL> per line",
56
+ type=argparse.FileType("r", encoding="UTF-8"),
57
+ )
58
+ parser.add_argument(
59
+ "--spm-model",
60
+ required=True,
61
+ help="sentencepiece model to use for encoding",
62
+ type=argparse.FileType("r", encoding="UTF-8"),
63
+ )
64
+ parser.add_argument(
65
+ "--dictionary",
66
+ required=True,
67
+ help="file to load fairseq dictionary from",
68
+ type=argparse.FileType("r", encoding="UTF-8"),
69
+ )
70
+ parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
71
+ parser.add_argument(
72
+ "--output",
73
+ required=True,
74
+ type=argparse.FileType("w"),
75
+ help="path to save json output",
76
+ )
77
+ args = parser.parse_args()
78
+
79
+ sp = spm.SentencePieceProcessor()
80
+ sp.Load(args.spm_model.name)
81
+
82
+ tgt_dict = Dictionary.load(args.dictionary)
83
+
84
+ labels = {}
85
+ for line in args.labels:
86
+ (utt_id, label) = line.split(" ", 1)
87
+ labels[utt_id] = label
88
+ if len(labels) == 0:
89
+ raise Exception("No labels found in ", args.labels_path)
90
+
91
+ Sample = namedtuple("Sample", "aud_path utt_id")
92
+ samples = []
93
+ for path, _, files in chain.from_iterable(
94
+ os.walk(path) for path in args.audio_dirs
95
+ ):
96
+ for f in files:
97
+ if f.endswith(args.audio_format):
98
+ if len(os.path.splitext(f)) != 2:
99
+ raise Exception("Expect <utt_id.extension> file name. Got: ", f)
100
+ utt_id = os.path.splitext(f)[0]
101
+ if utt_id not in labels:
102
+ continue
103
+ samples.append(Sample(os.path.join(path, f), utt_id))
104
+
105
+ utts = {}
106
+ num_cpu = multiprocessing.cpu_count()
107
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
108
+ future_to_sample = {
109
+ executor.submit(
110
+ process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
111
+ ): s
112
+ for s in samples
113
+ }
114
+ for future in concurrent.futures.as_completed(future_to_sample):
115
+ try:
116
+ data = future.result()
117
+ except Exception as exc:
118
+ print("generated an exception: ", exc)
119
+ else:
120
+ utts.update(data)
121
+ json.dump({"utts": utts}, args.output, indent=4)
122
+
123
+
124
+ if __name__ == "__main__":
125
+ main()
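A sketch of how the JSON written by `asr_prep_json.py` can be read back; the `train.json` path is a placeholder, and the keys mirror the ones built in `process_sample` above:
```
# Editor's sketch: inspect one utterance from the prepared JSON.
import json

with open("train.json") as f:      # placeholder path
    utts = json.load(f)["utts"]

utt_id, utt = next(iter(utts.items()))
print(utt["input"]["path"], utt["input"]["length_ms"])
print(utt["output"]["text"])       # raw transcription
print(utt["output"]["token"])      # space-separated sentencepiece pieces
print(utt["output"]["tokenid"])    # comma-separated dictionary indices
```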
fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh ADDED
@@ -0,0 +1,88 @@
1
+ #!/usr/bin/env bash
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Prepare librispeech dataset
8
+
9
+ base_url=www.openslr.org/resources/12
10
+ train_dir=train_960
11
+
12
+ if [ "$#" -ne 2 ]; then
13
+ echo "Usage: $0 <download_dir> <out_dir>"
14
+ echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
15
+ exit 1
16
+ fi
17
+
18
+ download_dir=${1%/}
19
+ out_dir=${2%/}
20
+
21
+ fairseq_root=~/fairseq-py/
22
+ mkdir -p ${out_dir}
23
+ cd ${out_dir} || exit
24
+
25
+ nbpe=5000
26
+ bpemode=unigram
27
+
28
+ if [ ! -d "$fairseq_root" ]; then
29
+ echo "$0: Please set correct fairseq_root"
30
+ exit 1
31
+ fi
32
+
33
+ echo "Data Download"
34
+ for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
35
+ url=$base_url/$part.tar.gz
36
+ if ! wget -P $download_dir $url; then
37
+ echo "$0: wget failed for $url"
38
+ exit 1
39
+ fi
40
+ if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
41
+ echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
42
+ exit 1
43
+ fi
44
+ done
45
+
46
+ echo "Merge all train packs into one"
47
+ mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
48
+ for part in train-clean-100 train-clean-360 train-other-500; do
49
+ mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
50
+ done
51
+ echo "Merge train text"
52
+ find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text
53
+
54
+ # Use combined dev-clean and dev-other as validation set
55
+ find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
56
+ find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
57
+ find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text
58
+
59
+
60
+ dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
61
+ encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
62
+ fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
63
+ bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
64
+ echo "dictionary: ${dict}"
65
+ echo "Dictionary preparation"
66
+ mkdir -p data/lang_char/
67
+ echo "<unk> 3" > ${dict}
68
+ echo "</s> 2" >> ${dict}
69
+ echo "<pad> 1" >> ${dict}
70
+ cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
71
+ spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
72
+ spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
73
+ cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
74
+ cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
75
+ wc -l ${dict}
76
+
77
+ echo "Prepare train and test jsons"
78
+ for part in train_960 test-other test-clean; do
79
+ python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
80
+ done
81
+ # fairseq expects to find train.json and valid.json during training
82
+ mv train_960.json train.json
83
+
84
+ echo "Prepare valid json"
85
+ python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json
86
+
87
+ cp ${fairseq_dict} ./dict.txt
88
+ cp ${bpemodel}.model ./spm.model
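A small sanity-check sketch, assuming the script above has been run and the exported `spm.model` / `dict.txt` sit in the current directory: encode a sample sentence the same way `asr_prep_json.py` does:
```
# Editor's sketch: dict.txt here is the fairseq-style dictionary exported above.
import sentencepiece as spm
from fairseq.data import Dictionary

sp = spm.SentencePieceProcessor()
sp.Load("spm.model")
tgt_dict = Dictionary.load("dict.txt")

pieces = " ".join(sp.EncodeAsPieces("HELLO WORLD"))
ids = tgt_dict.encode_line(pieces, append_eos=False)
print(pieces)
print(ids.tolist())
```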
fairseq/examples/speech_recognition/infer.py ADDED
@@ -0,0 +1,436 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Run inference for pre-processed data with a trained model.
9
+ """
10
+
11
+ import ast
12
+ import logging
13
+ import math
14
+ import os
15
+ import sys
16
+
17
+ import editdistance
18
+ import numpy as np
19
+ import torch
20
+ from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
21
+ from fairseq.data.data_utils import post_process
22
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
23
+
24
+
25
+ logging.basicConfig()
26
+ logging.root.setLevel(logging.INFO)
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def add_asr_eval_argument(parser):
32
+ parser.add_argument("--kspmodel", default=None, help="sentence piece model")
33
+ parser.add_argument(
34
+ "--wfstlm", default=None, help="wfstlm on dictonary output units"
35
+ )
36
+ parser.add_argument(
37
+ "--rnnt_decoding_type",
38
+ default="greedy",
39
+ help="wfstlm on dictonary\
40
+ output units",
41
+ )
42
+ try:
43
+ parser.add_argument(
44
+ "--lm-weight",
45
+ "--lm_weight",
46
+ type=float,
47
+ default=0.2,
48
+ help="weight for lm while interpolating with neural score",
49
+ )
50
+ except:
51
+ pass
52
+ parser.add_argument(
53
+ "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
54
+ )
55
+ parser.add_argument(
56
+ "--w2l-decoder",
57
+ choices=["viterbi", "kenlm", "fairseqlm"],
58
+ help="use a w2l decoder",
59
+ )
60
+ parser.add_argument("--lexicon", help="lexicon for w2l decoder")
61
+ parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
62
+ parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
63
+ parser.add_argument("--beam-threshold", type=float, default=25.0)
64
+ parser.add_argument("--beam-size-token", type=float, default=100)
65
+ parser.add_argument("--word-score", type=float, default=1.0)
66
+ parser.add_argument("--unk-weight", type=float, default=-math.inf)
67
+ parser.add_argument("--sil-weight", type=float, default=0.0)
68
+ parser.add_argument(
69
+ "--dump-emissions",
70
+ type=str,
71
+ default=None,
72
+ help="if present, dumps emissions into this file and exits",
73
+ )
74
+ parser.add_argument(
75
+ "--dump-features",
76
+ type=str,
77
+ default=None,
78
+ help="if present, dumps features into this file and exits",
79
+ )
80
+ parser.add_argument(
81
+ "--load-emissions",
82
+ type=str,
83
+ default=None,
84
+ help="if present, loads emissions from this file",
85
+ )
86
+ return parser
87
+
88
+
89
+ def check_args(args):
90
+ # assert args.path is not None, "--path required for generation!"
91
+ # assert args.results_path is not None, "--results_path required for generation!"
92
+ assert (
93
+ not args.sampling or args.nbest == args.beam
94
+ ), "--sampling requires --nbest to be equal to --beam"
95
+ assert (
96
+ args.replace_unk is None or args.raw_text
97
+ ), "--replace-unk requires a raw text dataset (--raw-text)"
98
+
99
+
100
+ def get_dataset_itr(args, task, models):
101
+ return task.get_batch_iterator(
102
+ dataset=task.dataset(args.gen_subset),
103
+ max_tokens=args.max_tokens,
104
+ max_sentences=args.batch_size,
105
+ max_positions=(sys.maxsize, sys.maxsize),
106
+ ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
107
+ required_batch_size_multiple=args.required_batch_size_multiple,
108
+ num_shards=args.num_shards,
109
+ shard_id=args.shard_id,
110
+ num_workers=args.num_workers,
111
+ data_buffer_size=args.data_buffer_size,
112
+ ).next_epoch_itr(shuffle=False)
113
+
114
+
115
+ def process_predictions(
116
+ args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
117
+ ):
118
+ for hypo in hypos[: min(len(hypos), args.nbest)]:
119
+ hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
120
+
121
+ if "words" in hypo:
122
+ hyp_words = " ".join(hypo["words"])
123
+ else:
124
+ hyp_words = post_process(hyp_pieces, args.post_process)
125
+
126
+ if res_files is not None:
127
+ print(
128
+ "{} ({}-{})".format(hyp_pieces, speaker, id),
129
+ file=res_files["hypo.units"],
130
+ )
131
+ print(
132
+ "{} ({}-{})".format(hyp_words, speaker, id),
133
+ file=res_files["hypo.words"],
134
+ )
135
+
136
+ tgt_pieces = tgt_dict.string(target_tokens)
137
+ tgt_words = post_process(tgt_pieces, args.post_process)
138
+
139
+ if res_files is not None:
140
+ print(
141
+ "{} ({}-{})".format(tgt_pieces, speaker, id),
142
+ file=res_files["ref.units"],
143
+ )
144
+ print(
145
+ "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
146
+ )
147
+
148
+ if not args.quiet:
149
+ logger.info("HYPO:" + hyp_words)
150
+ logger.info("TARGET:" + tgt_words)
151
+ logger.info("___________________")
152
+
153
+ hyp_words = hyp_words.split()
154
+ tgt_words = tgt_words.split()
155
+ return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
156
+
157
+
158
+ def prepare_result_files(args):
159
+ def get_res_file(file_prefix):
160
+ if args.num_shards > 1:
161
+ file_prefix = f"{args.shard_id}_{file_prefix}"
162
+ path = os.path.join(
163
+ args.results_path,
164
+ "{}-{}-{}.txt".format(
165
+ file_prefix, os.path.basename(args.path), args.gen_subset
166
+ ),
167
+ )
168
+ return open(path, "w", buffering=1)
169
+
170
+ if not args.results_path:
171
+ return None
172
+
173
+ return {
174
+ "hypo.words": get_res_file("hypo.word"),
175
+ "hypo.units": get_res_file("hypo.units"),
176
+ "ref.words": get_res_file("ref.word"),
177
+ "ref.units": get_res_file("ref.units"),
178
+ }
179
+
180
+
181
+ def optimize_models(args, use_cuda, models):
182
+ """Optimize ensemble for generation"""
183
+ for model in models:
184
+ model.make_generation_fast_(
185
+ beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
186
+ need_attn=args.print_alignment,
187
+ )
188
+ if args.fp16:
189
+ model.half()
190
+ if use_cuda:
191
+ model.cuda()
192
+
193
+
194
+ def apply_half(t):
195
+ if t.dtype is torch.float32:
196
+ return t.to(dtype=torch.half)
197
+ return t
198
+
199
+
200
+ class ExistingEmissionsDecoder(object):
201
+ def __init__(self, decoder, emissions):
202
+ self.decoder = decoder
203
+ self.emissions = emissions
204
+
205
+ def generate(self, models, sample, **unused):
206
+ ids = sample["id"].cpu().numpy()
207
+ try:
208
+ emissions = np.stack(self.emissions[ids])
209
+ except:
210
+ print([x.shape for x in self.emissions[ids]])
211
+ raise Exception("invalid sizes")
212
+ emissions = torch.from_numpy(emissions)
213
+ return self.decoder.decode(emissions)
214
+
215
+
216
+ def main(args, task=None, model_state=None):
217
+ check_args(args)
218
+
219
+ use_fp16 = args.fp16
220
+ if args.max_tokens is None and args.batch_size is None:
221
+ args.max_tokens = 4000000
222
+ logger.info(args)
223
+
224
+ use_cuda = torch.cuda.is_available() and not args.cpu
225
+
226
+ logger.info("| decoding with criterion {}".format(args.criterion))
227
+
228
+ task = tasks.setup_task(args)
229
+
230
+ # Load ensemble
231
+ if args.load_emissions:
232
+ models, criterions = [], []
233
+ task.load_dataset(args.gen_subset)
234
+ else:
235
+ logger.info("| loading model(s) from {}".format(args.path))
236
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
237
+ utils.split_paths(args.path, separator="\\"),
238
+ arg_overrides=ast.literal_eval(args.model_overrides),
239
+ task=task,
240
+ suffix=args.checkpoint_suffix,
241
+ strict=(args.checkpoint_shard_count == 1),
242
+ num_shards=args.checkpoint_shard_count,
243
+ state=model_state,
244
+ )
245
+ optimize_models(args, use_cuda, models)
246
+ task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)
247
+
248
+
249
+ # Set dictionary
250
+ tgt_dict = task.target_dictionary
251
+
252
+ logger.info(
253
+ "| {} {} {} examples".format(
254
+ args.data, args.gen_subset, len(task.dataset(args.gen_subset))
255
+ )
256
+ )
257
+
258
+ # hack to pass transitions to W2lDecoder
259
+ if args.criterion == "asg_loss":
260
+ raise NotImplementedError("asg_loss is currently not supported")
261
+ # trans = criterions[0].asg.trans.data
262
+ # args.asg_transitions = torch.flatten(trans).tolist()
263
+
264
+ # Load dataset (possibly sharded)
265
+ itr = get_dataset_itr(args, task, models)
266
+
267
+ # Initialize generator
268
+ gen_timer = StopwatchMeter()
269
+
270
+ def build_generator(args):
271
+ w2l_decoder = getattr(args, "w2l_decoder", None)
272
+ if w2l_decoder == "viterbi":
273
+ from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
274
+
275
+ return W2lViterbiDecoder(args, task.target_dictionary)
276
+ elif w2l_decoder == "kenlm":
277
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
278
+
279
+ return W2lKenLMDecoder(args, task.target_dictionary)
280
+ elif w2l_decoder == "fairseqlm":
281
+ from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
282
+
283
+ return W2lFairseqLMDecoder(args, task.target_dictionary)
284
+ else:
285
+ print(
286
+ "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
287
+ )
288
+
289
+ # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
290
+ generator = build_generator(args)
291
+
292
+ if args.load_emissions:
293
+ generator = ExistingEmissionsDecoder(
294
+ generator, np.load(args.load_emissions, allow_pickle=True)
295
+ )
296
+ logger.info("loaded emissions from " + args.load_emissions)
297
+
298
+ num_sentences = 0
299
+
300
+ if args.results_path is not None and not os.path.exists(args.results_path):
301
+ os.makedirs(args.results_path)
302
+
303
+ max_source_pos = (
304
+ utils.resolve_max_positions(
305
+ task.max_positions(), *[model.max_positions() for model in models]
306
+ ),
307
+ )
308
+
309
+ if max_source_pos is not None:
310
+ max_source_pos = max_source_pos[0]
311
+ if max_source_pos is not None:
312
+ max_source_pos = max_source_pos[0] - 1
313
+
314
+ if args.dump_emissions:
315
+ emissions = {}
316
+ if args.dump_features:
317
+ features = {}
318
+ models[0].bert.proj = None
319
+ else:
320
+ res_files = prepare_result_files(args)
321
+ errs_t = 0
322
+ lengths_t = 0
323
+ with progress_bar.build_progress_bar(args, itr) as t:
324
+ wps_meter = TimeMeter()
325
+ for sample in t:
326
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
327
+ if use_fp16:
328
+ sample = utils.apply_to_sample(apply_half, sample)
329
+ if "net_input" not in sample:
330
+ continue
331
+
332
+ prefix_tokens = None
333
+ if args.prefix_size > 0:
334
+ prefix_tokens = sample["target"][:, : args.prefix_size]
335
+
336
+ gen_timer.start()
337
+ if args.dump_emissions:
338
+ with torch.no_grad():
339
+ encoder_out = models[0](**sample["net_input"])
340
+ emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
341
+ emm = emm.transpose(0, 1).cpu().numpy()
342
+ for i, id in enumerate(sample["id"]):
343
+ emissions[id.item()] = emm[i]
344
+ continue
345
+ elif args.dump_features:
346
+ with torch.no_grad():
347
+ encoder_out = models[0](**sample["net_input"])
348
+ feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
349
+ for i, id in enumerate(sample["id"]):
350
+ padding = (
351
+ encoder_out["encoder_padding_mask"][i].cpu().numpy()
352
+ if encoder_out["encoder_padding_mask"] is not None
353
+ else None
354
+ )
355
+ features[id.item()] = (feat[i], padding)
356
+ continue
357
+ hypos = task.inference_step(generator, models, sample, prefix_tokens)
358
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
359
+ gen_timer.stop(num_generated_tokens)
360
+
361
+ for i, sample_id in enumerate(sample["id"].tolist()):
362
+ speaker = None
363
+ # id = task.dataset(args.gen_subset).ids[int(sample_id)]
364
+ id = sample_id
365
+ toks = (
366
+ sample["target"][i, :]
367
+ if "target_label" not in sample
368
+ else sample["target_label"][i, :]
369
+ )
370
+ target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
371
+ # Process top predictions
372
+ errs, length = process_predictions(
373
+ args,
374
+ hypos[i],
375
+ None,
376
+ tgt_dict,
377
+ target_tokens,
378
+ res_files,
379
+ speaker,
380
+ id,
381
+ )
382
+ errs_t += errs
383
+ lengths_t += length
384
+
385
+ wps_meter.update(num_generated_tokens)
386
+ t.log({"wps": round(wps_meter.avg)})
387
+ num_sentences += (
388
+ sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
389
+ )
390
+
391
+ wer = None
392
+ if args.dump_emissions:
393
+ emm_arr = []
394
+ for i in range(len(emissions)):
395
+ emm_arr.append(emissions[i])
396
+ np.save(args.dump_emissions, emm_arr)
397
+ logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
398
+ elif args.dump_features:
399
+ feat_arr = []
400
+ for i in range(len(features)):
401
+ feat_arr.append(features[i])
402
+ np.save(args.dump_features, feat_arr)
403
+ logger.info(f"saved {len(features)} features to {args.dump_features}")
404
+ else:
405
+ if lengths_t > 0:
406
+ wer = errs_t * 100.0 / lengths_t
407
+ logger.info(f"WER: {wer}")
408
+
409
+ logger.info(
410
+ "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
411
+ "sentences/s, {:.2f} tokens/s)".format(
412
+ num_sentences,
413
+ gen_timer.n,
414
+ gen_timer.sum,
415
+ num_sentences / gen_timer.sum,
416
+ 1.0 / gen_timer.avg,
417
+ )
418
+ )
419
+ logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
420
+ return task, wer
421
+
422
+
423
+ def make_parser():
424
+ parser = options.get_generation_parser()
425
+ parser = add_asr_eval_argument(parser)
426
+ return parser
427
+
428
+
429
+ def cli_main():
430
+ parser = make_parser()
431
+ args = options.parse_args_and_arch(parser)
432
+ main(args)
433
+
434
+
435
+ if __name__ == "__main__":
436
+ cli_main()
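A minimal sketch of driving this script programmatically; the data directory, checkpoint, subset, and task below are placeholders, and real setups usually need extra flags (lexicon, language-model weights, post-processing) depending on the chosen decoder:

    from fairseq import options
    from examples.speech_recognition.infer import make_parser, main

    parser = make_parser()
    args = options.parse_args_and_arch(
        parser,
        input_args=[
            "/path/to/manifest_dir",        # positional data argument (placeholder)
            "--task", "audio_pretraining",  # assumed task, per the comment in main()
            "--labels", "ltr",
            "--path", "/path/to/model.pt",
            "--gen-subset", "dev_other",
            "--criterion", "ctc",
            "--w2l-decoder", "viterbi",     # or "kenlm" / "fairseqlm", see build_generator()
            "--results-path", "/tmp/decode_out",
        ],
    )
    task, wer = main(args)  # main() returns the task and the word error rate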
fairseq/examples/speech_recognition/kaldi/__init__.py ADDED
File without changes
fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc ADDED
@@ -0,0 +1,94 @@
1
+ /*
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <iostream>
9
+ #include "fstext/fstext-lib.h" // @manual
10
+ #include "util/common-utils.h" // @manual
11
+
12
+ /*
13
+ * This program modifies an FST that has no self-loops by:
14
+ * for each incoming arc with non-eps input symbol, add a self-loop arc
15
+ * with that non-eps symbol as input and eps as output.
16
+ *
17
+ * This is to make sure the resultant FST can do deduplication for repeated
18
+ * symbols, which are very common in acoustic model outputs.
19
+ *
20
+ */
21
+ namespace {
22
+ int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) {
23
+ typedef fst::MutableArcIterator<fst::StdVectorFst> IterType;
24
+
25
+ int32 num_states_before = fst->NumStates();
26
+ fst::MakePrecedingInputSymbolsSame(false, fst);
27
+ int32 num_states_after = fst->NumStates();
28
+ KALDI_LOG << "There are " << num_states_before
29
+ << " states in the original FST; "
30
+ << " after MakePrecedingInputSymbolsSame, there are "
31
+ << num_states_after << " states " << std::endl;
32
+
33
+ auto weight_one = fst::StdArc::Weight::One();
34
+
35
+ int32 num_arc_added = 0;
36
+
37
+ fst::StdArc self_loop_arc;
38
+ self_loop_arc.weight = weight_one;
39
+
40
+ int32 num_states = fst->NumStates();
41
+ std::vector<std::set<int32>> incoming_non_eps_label_per_state(num_states);
42
+
43
+ for (int32 state = 0; state < num_states; state++) {
44
+ for (IterType aiter(fst, state); !aiter.Done(); aiter.Next()) {
45
+ fst::StdArc arc(aiter.Value());
46
+ if (arc.ilabel != 0) {
47
+ incoming_non_eps_label_per_state[arc.nextstate].insert(arc.ilabel);
48
+ }
49
+ }
50
+ }
51
+
52
+ for (int32 state = 0; state < num_states; state++) {
53
+ if (!incoming_non_eps_label_per_state[state].empty()) {
54
+ auto& ilabel_set = incoming_non_eps_label_per_state[state];
55
+ for (auto it = ilabel_set.begin(); it != ilabel_set.end(); it++) {
56
+ self_loop_arc.ilabel = *it;
57
+ self_loop_arc.olabel = 0;
58
+ self_loop_arc.nextstate = state;
59
+ fst->AddArc(state, self_loop_arc);
60
+ num_arc_added++;
61
+ }
62
+ }
63
+ }
64
+ return num_arc_added;
65
+ }
66
+
67
+ void print_usage() {
68
+ std::cout << "add-self-loop-simple usage:\n"
69
+ "\tadd-self-loop-simple <in-fst> <out-fst> \n";
70
+ }
71
+ } // namespace
72
+
73
+ int main(int argc, char** argv) {
74
+ if (argc != 3) {
75
+ print_usage();
76
+ exit(1);
77
+ }
78
+
79
+ auto input = argv[1];
80
+ auto output = argv[2];
81
+
82
+ auto fst = fst::ReadFstKaldi(input);
83
+ auto num_states = fst->NumStates();
84
+ KALDI_LOG << "Loading FST from " << input << " with " << num_states
85
+ << " states." << std::endl;
86
+
87
+ int32 num_arc_added = AddSelfLoopsSimple(fst);
88
+ KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl;
89
+
90
+ fst::WriteFstKaldi(*fst, std::string(output));
91
+ KALDI_LOG << "Writing FST to " << output << std::endl;
92
+
93
+ delete fst;
94
+ }
fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # @package _group_
2
+
3
+ data_dir: ???
4
+ fst_dir: ???
5
+ in_labels: ???
6
+ kaldi_root: ???
7
+ lm_arpa: ???
8
+ blank_symbol: <s>
fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py ADDED
@@ -0,0 +1,244 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from concurrent.futures import ThreadPoolExecutor
9
+ import logging
10
+ from omegaconf import MISSING
11
+ import os
12
+ import torch
13
+ from typing import Optional
14
+ import warnings
15
+
16
+
17
+ from dataclasses import dataclass
18
+ from fairseq.dataclass import FairseqDataclass
19
+ from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass
26
+ class KaldiDecoderConfig(FairseqDataclass):
27
+ hlg_graph_path: Optional[str] = None
28
+ output_dict: str = MISSING
29
+
30
+ kaldi_initializer_config: Optional[KaldiInitializerConfig] = None
31
+
32
+ acoustic_scale: float = 0.5
33
+ max_active: int = 10000
34
+ beam_delta: float = 0.5
35
+ hash_ratio: float = 2.0
36
+
37
+ is_lattice: bool = False
38
+ lattice_beam: float = 10.0
39
+ prune_interval: int = 25
40
+ determinize_lattice: bool = True
41
+ prune_scale: float = 0.1
42
+ max_mem: int = 0
43
+ phone_determinize: bool = True
44
+ word_determinize: bool = True
45
+ minimize: bool = True
46
+
47
+ num_threads: int = 1
48
+
49
+
50
+ class KaldiDecoder(object):
51
+ def __init__(
52
+ self,
53
+ cfg: KaldiDecoderConfig,
54
+ beam: int,
55
+ nbest: int = 1,
56
+ ):
57
+ try:
58
+ from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
59
+ from kaldi.base import set_verbose_level
60
+ from kaldi.decoder import (
61
+ FasterDecoder,
62
+ FasterDecoderOptions,
63
+ LatticeFasterDecoder,
64
+ LatticeFasterDecoderOptions,
65
+ )
66
+ from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
67
+ from kaldi.fstext import read_fst_kaldi, SymbolTable
68
+ except ImportError:
69
+ warnings.warn(
70
+ "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi"
71
+ )
72
+
73
+ # set_verbose_level(2)
74
+
75
+ self.acoustic_scale = cfg.acoustic_scale
76
+ self.nbest = nbest
77
+
78
+ if cfg.hlg_graph_path is None:
79
+ assert (
80
+ cfg.kaldi_initializer_config is not None
81
+ ), "Must provide hlg graph path or kaldi initializer config"
82
+ cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)
83
+
84
+ assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path
85
+
86
+ if cfg.is_lattice:
87
+ self.dec_cls = LatticeFasterDecoder
88
+ opt_cls = LatticeFasterDecoderOptions
89
+ self.rec_cls = LatticeFasterRecognizer
90
+ else:
91
+ assert self.nbest == 1, "nbest > 1 requires lattice decoder"
92
+ self.dec_cls = FasterDecoder
93
+ opt_cls = FasterDecoderOptions
94
+ self.rec_cls = FasterRecognizer
95
+
96
+ self.decoder_options = opt_cls()
97
+ self.decoder_options.beam = beam
98
+ self.decoder_options.max_active = cfg.max_active
99
+ self.decoder_options.beam_delta = cfg.beam_delta
100
+ self.decoder_options.hash_ratio = cfg.hash_ratio
101
+
102
+ if cfg.is_lattice:
103
+ self.decoder_options.lattice_beam = cfg.lattice_beam
104
+ self.decoder_options.prune_interval = cfg.prune_interval
105
+ self.decoder_options.determinize_lattice = cfg.determinize_lattice
106
+ self.decoder_options.prune_scale = cfg.prune_scale
107
+ det_opts = DeterminizeLatticePhonePrunedOptions()
108
+ det_opts.max_mem = cfg.max_mem
109
+ det_opts.phone_determinize = cfg.phone_determinize
110
+ det_opts.word_determinize = cfg.word_determinize
111
+ det_opts.minimize = cfg.minimize
112
+ self.decoder_options.det_opts = det_opts
113
+
114
+ self.output_symbols = {}
115
+ with open(cfg.output_dict, "r") as f:
116
+ for line in f:
117
+ items = line.rstrip().split()
118
+ assert len(items) == 2
119
+ self.output_symbols[int(items[1])] = items[0]
120
+
121
+ logger.info(f"Loading FST from {cfg.hlg_graph_path}")
122
+ self.fst = read_fst_kaldi(cfg.hlg_graph_path)
123
+ self.symbol_table = SymbolTable.read_text(cfg.output_dict)
124
+
125
+ self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)
126
+
127
+ def generate(self, models, sample, **unused):
128
+ """Generate a batch of inferences."""
129
+ # model.forward normally channels prev_output_tokens into the decoder
130
+ # separately, but SequenceGenerator directly calls model.encoder
131
+ encoder_input = {
132
+ k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
133
+ }
134
+ emissions, padding = self.get_emissions(models, encoder_input)
135
+ return self.decode(emissions, padding)
136
+
137
+ def get_emissions(self, models, encoder_input):
138
+ """Run encoder and normalize emissions"""
139
+ model = models[0]
140
+
141
+ all_encoder_out = [m(**encoder_input) for m in models]
142
+
143
+ if len(all_encoder_out) > 1:
144
+
145
+ if "encoder_out" in all_encoder_out[0]:
146
+ encoder_out = {
147
+ "encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
148
+ / len(all_encoder_out),
149
+ "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
150
+ }
151
+ padding = encoder_out["encoder_padding_mask"]
152
+ else:
153
+ encoder_out = {
154
+ "logits": sum(e["logits"] for e in all_encoder_out)
155
+ / len(all_encoder_out),
156
+ "padding_mask": all_encoder_out[0]["padding_mask"],
157
+ }
158
+ padding = encoder_out["padding_mask"]
159
+ else:
160
+ encoder_out = all_encoder_out[0]
161
+ padding = (
162
+ encoder_out["padding_mask"]
163
+ if "padding_mask" in encoder_out
164
+ else encoder_out["encoder_padding_mask"]
165
+ )
166
+
167
+ if hasattr(model, "get_logits"):
168
+ emissions = model.get_logits(encoder_out, normalize=True)
169
+ else:
170
+ emissions = model.get_normalized_probs(encoder_out, log_probs=True)
171
+
172
+ return (
173
+ emissions.cpu().float().transpose(0, 1),
174
+ padding.cpu() if padding is not None and padding.any() else None,
175
+ )
176
+
177
+ def decode_one(self, logits, padding):
178
+ from kaldi.matrix import Matrix
179
+
180
+ decoder = self.dec_cls(self.fst, self.decoder_options)
181
+ asr = self.rec_cls(
182
+ decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
183
+ )
184
+
185
+ if padding is not None:
186
+ logits = logits[~padding]
187
+
188
+ mat = Matrix(logits.numpy())
189
+
190
+ out = asr.decode(mat)
191
+
192
+ if self.nbest > 1:
193
+ from kaldi.fstext import shortestpath
194
+ from kaldi.fstext.utils import (
195
+ convert_compact_lattice_to_lattice,
196
+ convert_lattice_to_std,
197
+ convert_nbest_to_list,
198
+ get_linear_symbol_sequence,
199
+ )
200
+
201
+ lat = out["lattice"]
202
+
203
+ sp = shortestpath(lat, nshortest=self.nbest)
204
+
205
+ sp = convert_compact_lattice_to_lattice(sp)
206
+ sp = convert_lattice_to_std(sp)
207
+ seq = convert_nbest_to_list(sp)
208
+
209
+ results = []
210
+ for s in seq:
211
+ _, o, w = get_linear_symbol_sequence(s)
212
+ words = list(self.output_symbols[z] for z in o)
213
+ results.append(
214
+ {
215
+ "tokens": words,
216
+ "words": words,
217
+ "score": w.value,
218
+ "emissions": logits,
219
+ }
220
+ )
221
+ return results
222
+ else:
223
+ words = out["text"].split()
224
+ return [
225
+ {
226
+ "tokens": words,
227
+ "words": words,
228
+ "score": out["likelihood"],
229
+ "emissions": logits,
230
+ }
231
+ ]
232
+
233
+ def decode(self, emissions, padding):
234
+ if padding is None:
235
+ padding = [None] * len(emissions)
236
+
237
+ ret = list(
238
+ map(
239
+ lambda e, p: self.executor.submit(self.decode_one, e, p),
240
+ emissions,
241
+ padding,
242
+ )
243
+ )
244
+ return ret
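A rough usage sketch for KaldiDecoder, assuming pykaldi is installed and an HLG graph plus its word/id symbol table already exist (all paths below are placeholders); decode() returns one future per utterance:

    import torch
    from examples.speech_recognition.kaldi.kaldi_decoder import (
        KaldiDecoder,
        KaldiDecoderConfig,
    )

    cfg = KaldiDecoderConfig(
        hlg_graph_path="/path/to/HLG.fst",  # placeholder graph
        output_dict="/path/to/words.txt",   # placeholder word/id symbol table
    )
    decoder = KaldiDecoder(cfg, beam=15)

    # two fake utterances: 100 frames over a 30-symbol output vocabulary
    emissions = torch.randn(2, 100, 30).log_softmax(dim=-1)
    futures = decoder.decode(emissions, padding=None)
    for fut in futures:
        best = fut.result()[0]  # decode_one() returns a list of hypotheses
        print(best["words"], best["score"])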
fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py ADDED
@@ -0,0 +1,698 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from dataclasses import dataclass
9
+ import hydra
10
+ from hydra.core.config_store import ConfigStore
11
+ import logging
12
+ from omegaconf import MISSING, OmegaConf
13
+ import os
14
+ import os.path as osp
15
+ from pathlib import Path
16
+ import subprocess
17
+ from typing import Optional
18
+
19
+ from fairseq.data.dictionary import Dictionary
20
+ from fairseq.dataclass import FairseqDataclass
21
+
22
+ script_dir = Path(__file__).resolve().parent
23
+ config_path = script_dir / "config"
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ @dataclass
30
+ class KaldiInitializerConfig(FairseqDataclass):
31
+ data_dir: str = MISSING
32
+ fst_dir: Optional[str] = None
33
+ in_labels: str = MISSING
34
+ out_labels: Optional[str] = None
35
+ wav2letter_lexicon: Optional[str] = None
36
+ lm_arpa: str = MISSING
37
+ kaldi_root: str = MISSING
38
+ blank_symbol: str = "<s>"
39
+ silence_symbol: Optional[str] = None
40
+
41
+
42
+ def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path:
43
+ in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt"
44
+ if not in_units_file.exists():
45
+
46
+ logger.info(f"Creating {in_units_file}")
47
+
48
+ with open(in_units_file, "w") as f:
49
+ print("<eps> 0", file=f)
50
+ i = 1
51
+ for symb in vocab.symbols[vocab.nspecial :]:
52
+ if not symb.startswith("madeupword"):
53
+ print(f"{symb} {i}", file=f)
54
+ i += 1
55
+ return in_units_file
56
+
57
+
58
+ def create_lexicon(
59
+ cfg: KaldiInitializerConfig,
60
+ fst_dir: Path,
61
+ unique_label: str,
62
+ in_units_file: Path,
63
+ out_words_file: Path,
64
+ ) -> (Path, Path):
65
+
66
+ disambig_in_units_file = fst_dir / f"kaldi_dict.{cfg.in_labels}_disambig.txt"
67
+ lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}.txt"
68
+ disambig_lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}_disambig.txt"
69
+ if (
70
+ not lexicon_file.exists()
71
+ or not disambig_lexicon_file.exists()
72
+ or not disambig_in_units_file.exists()
73
+ ):
74
+ logger.info(f"Creating {lexicon_file} (in units file: {in_units_file})")
75
+
76
+ assert cfg.wav2letter_lexicon is not None or cfg.in_labels == cfg.out_labels
77
+
78
+ if cfg.wav2letter_lexicon is not None:
79
+ lm_words = set()
80
+ with open(out_words_file, "r") as lm_dict_f:
81
+ for line in lm_dict_f:
82
+ lm_words.add(line.split()[0])
83
+
84
+ num_skipped = 0
85
+ total = 0
86
+ with open(cfg.wav2letter_lexicon, "r") as w2l_lex_f, open(
87
+ lexicon_file, "w"
88
+ ) as out_f:
89
+ for line in w2l_lex_f:
90
+ items = line.rstrip().split("\t")
91
+ assert len(items) == 2, items
92
+ if items[0] in lm_words:
93
+ print(items[0], items[1], file=out_f)
94
+ else:
95
+ num_skipped += 1
96
+ logger.debug(
97
+ f"Skipping word {items[0]} as it was not found in LM"
98
+ )
99
+ total += 1
100
+ if num_skipped > 0:
101
+ logger.warning(
102
+ f"Skipped {num_skipped} out of {total} words as they were not found in LM"
103
+ )
104
+ else:
105
+ with open(in_units_file, "r") as in_f, open(lexicon_file, "w") as out_f:
106
+ for line in in_f:
107
+ symb = line.split()[0]
108
+ if symb != "<eps>" and symb != "<ctc_blank>" and symb != "<SIL>":
109
+ print(symb, symb, file=out_f)
110
+
111
+ lex_disambig_path = (
112
+ Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_lex_disambig.pl"
113
+ )
114
+ res = subprocess.run(
115
+ [lex_disambig_path, lexicon_file, disambig_lexicon_file],
116
+ check=True,
117
+ capture_output=True,
118
+ )
119
+ ndisambig = int(res.stdout)
120
+ disamib_path = Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_disambig.pl"
121
+ res = subprocess.run(
122
+ [disamib_path, "--include-zero", in_units_file, str(ndisambig)],
123
+ check=True,
124
+ capture_output=True,
125
+ )
126
+ with open(disambig_in_units_file, "wb") as f:
127
+ f.write(res.stdout)
128
+
129
+ return disambig_lexicon_file, disambig_in_units_file
130
+
131
+
132
+ def create_G(
133
+ kaldi_root: Path, fst_dir: Path, lm_arpa: Path, arpa_base: str
134
+ ) -> (Path, Path):
135
+
136
+ out_words_file = fst_dir / f"kaldi_dict.{arpa_base}.txt"
137
+ grammar_graph = fst_dir / f"G_{arpa_base}.fst"
138
+ if not grammar_graph.exists() or not out_words_file.exists():
139
+ logger.info(f"Creating {grammar_graph}")
140
+ arpa2fst = kaldi_root / "src/lmbin/arpa2fst"
141
+ subprocess.run(
142
+ [
143
+ arpa2fst,
144
+ "--disambig-symbol=#0",
145
+ f"--write-symbol-table={out_words_file}",
146
+ lm_arpa,
147
+ grammar_graph,
148
+ ],
149
+ check=True,
150
+ )
151
+ return grammar_graph, out_words_file
152
+
153
+
154
+ def create_L(
155
+ kaldi_root: Path,
156
+ fst_dir: Path,
157
+ unique_label: str,
158
+ lexicon_file: Path,
159
+ in_units_file: Path,
160
+ out_words_file: Path,
161
+ ) -> Path:
162
+ lexicon_graph = fst_dir / f"L.{unique_label}.fst"
163
+
164
+ if not lexicon_graph.exists():
165
+ logger.info(f"Creating {lexicon_graph} (in units: {in_units_file})")
166
+ make_lex = kaldi_root / "egs/wsj/s5/utils/make_lexicon_fst.pl"
167
+ fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
168
+ fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
169
+ fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"
170
+
171
+ def write_disambig_symbol(file):
172
+ with open(file, "r") as f:
173
+ for line in f:
174
+ items = line.rstrip().split()
175
+ if items[0] == "#0":
176
+ out_path = str(file) + "_disamig"
177
+ with open(out_path, "w") as out_f:
178
+ print(items[1], file=out_f)
179
+ return out_path
180
+
181
+ return None
182
+
183
+ in_disambig_sym = write_disambig_symbol(in_units_file)
184
+ assert in_disambig_sym is not None
185
+ out_disambig_sym = write_disambig_symbol(out_words_file)
186
+ assert out_disambig_sym is not None
187
+
188
+ try:
189
+ with open(lexicon_graph, "wb") as out_f:
190
+ res = subprocess.run(
191
+ [make_lex, lexicon_file], capture_output=True, check=True
192
+ )
193
+ assert len(res.stderr) == 0, res.stderr.decode("utf-8")
194
+ res = subprocess.run(
195
+ [
196
+ fstcompile,
197
+ f"--isymbols={in_units_file}",
198
+ f"--osymbols={out_words_file}",
199
+ "--keep_isymbols=false",
200
+ "--keep_osymbols=false",
201
+ ],
202
+ input=res.stdout,
203
+ capture_output=True,
204
+ )
205
+ assert len(res.stderr) == 0, res.stderr.decode("utf-8")
206
+ res = subprocess.run(
207
+ [fstaddselfloops, in_disambig_sym, out_disambig_sym],
208
+ input=res.stdout,
209
+ capture_output=True,
210
+ check=True,
211
+ )
212
+ res = subprocess.run(
213
+ [fstarcsort, "--sort_type=olabel"],
214
+ input=res.stdout,
215
+ capture_output=True,
216
+ check=True,
217
+ )
218
+ out_f.write(res.stdout)
219
+ except subprocess.CalledProcessError as e:
220
+ logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
221
+ os.remove(lexicon_graph)
222
+ raise
223
+ except AssertionError:
224
+ os.remove(lexicon_graph)
225
+ raise
226
+
227
+ return lexicon_graph
228
+
229
+
230
+ def create_LG(
231
+ kaldi_root: Path,
232
+ fst_dir: Path,
233
+ unique_label: str,
234
+ lexicon_graph: Path,
235
+ grammar_graph: Path,
236
+ ) -> Path:
237
+ lg_graph = fst_dir / f"LG.{unique_label}.fst"
238
+
239
+ if not lg_graph.exists():
240
+ logger.info(f"Creating {lg_graph}")
241
+
242
+ fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
243
+ fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
244
+ fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
245
+ fstpushspecial = kaldi_root / "src/fstbin/fstpushspecial"
246
+ fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"
247
+
248
+ try:
249
+ with open(lg_graph, "wb") as out_f:
250
+ res = subprocess.run(
251
+ [fsttablecompose, lexicon_graph, grammar_graph],
252
+ capture_output=True,
253
+ check=True,
254
+ )
255
+ res = subprocess.run(
256
+ [
257
+ fstdeterminizestar,
258
+ "--use-log=true",
259
+ ],
260
+ input=res.stdout,
261
+ capture_output=True,
262
+ )
263
+ res = subprocess.run(
264
+ [fstminimizeencoded],
265
+ input=res.stdout,
266
+ capture_output=True,
267
+ check=True,
268
+ )
269
+ res = subprocess.run(
270
+ [fstpushspecial],
271
+ input=res.stdout,
272
+ capture_output=True,
273
+ check=True,
274
+ )
275
+ res = subprocess.run(
276
+ [fstarcsort, "--sort_type=ilabel"],
277
+ input=res.stdout,
278
+ capture_output=True,
279
+ check=True,
280
+ )
281
+ out_f.write(res.stdout)
282
+ except subprocess.CalledProcessError as e:
283
+ logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
284
+ os.remove(lg_graph)
285
+ raise
286
+
287
+ return lg_graph
288
+
289
+
290
+ def create_H(
291
+ kaldi_root: Path,
292
+ fst_dir: Path,
293
+ disambig_out_units_file: Path,
294
+ in_labels: str,
295
+ vocab: Dictionary,
296
+ blk_sym: str,
297
+ silence_symbol: Optional[str],
298
+ ) -> (Path, Path, Path):
299
+ h_graph = (
300
+ fst_dir / f"H.{in_labels}{'_' + silence_symbol if silence_symbol else ''}.fst"
301
+ )
302
+ h_out_units_file = fst_dir / f"kaldi_dict.h_out.{in_labels}.txt"
303
+ disambig_in_units_file_int = Path(str(h_graph) + "isym_disambig.int")
304
+ disambig_out_units_file_int = Path(str(disambig_out_units_file) + ".int")
305
+ if (
306
+ not h_graph.exists()
307
+ or not h_out_units_file.exists()
308
+ or not disambig_in_units_file_int.exists()
309
+ ):
310
+ logger.info(f"Creating {h_graph}")
311
+ eps_sym = "<eps>"
312
+
313
+ num_disambig = 0
314
+ osymbols = []
315
+
316
+ with open(disambig_out_units_file, "r") as f, open(
317
+ disambig_out_units_file_int, "w"
318
+ ) as out_f:
319
+ for line in f:
320
+ symb, id = line.rstrip().split()
321
+ if line.startswith("#"):
322
+ num_disambig += 1
323
+ print(id, file=out_f)
324
+ else:
325
+ if len(osymbols) == 0:
326
+ assert symb == eps_sym, symb
327
+ osymbols.append((symb, id))
328
+
329
+ i_idx = 0
330
+ isymbols = [(eps_sym, 0)]
331
+
332
+ imap = {}
333
+
334
+ for i, s in enumerate(vocab.symbols):
335
+ i_idx += 1
336
+ isymbols.append((s, i_idx))
337
+ imap[s] = i_idx
338
+
339
+ fst_str = []
340
+
341
+ node_idx = 0
342
+ root_node = node_idx
343
+
344
+ special_symbols = [blk_sym]
345
+ if silence_symbol is not None:
346
+ special_symbols.append(silence_symbol)
347
+
348
+ for ss in special_symbols:
349
+ fst_str.append("{} {} {} {}".format(root_node, root_node, ss, eps_sym))
350
+
351
+ for symbol, _ in osymbols:
352
+ if symbol == eps_sym or symbol.startswith("#"):
353
+ continue
354
+
355
+ node_idx += 1
356
+ # 1. from root to emitting state
357
+ fst_str.append("{} {} {} {}".format(root_node, node_idx, symbol, symbol))
358
+ # 2. from emitting state back to root
359
+ fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))
360
+ # 3. from emitting state to optional blank state
361
+ pre_node = node_idx
362
+ node_idx += 1
363
+ for ss in special_symbols:
364
+ fst_str.append("{} {} {} {}".format(pre_node, node_idx, ss, eps_sym))
365
+ # 4. from blank state back to root
366
+ fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))
367
+
368
+ fst_str.append("{}".format(root_node))
369
+
370
+ fst_str = "\n".join(fst_str)
371
+ h_str = str(h_graph)
372
+ isym_file = h_str + ".isym"
373
+
374
+ with open(isym_file, "w") as f:
375
+ for sym, id in isymbols:
376
+ f.write("{} {}\n".format(sym, id))
377
+
378
+ with open(h_out_units_file, "w") as f:
379
+ for sym, id in osymbols:
380
+ f.write("{} {}\n".format(sym, id))
381
+
382
+ with open(disambig_in_units_file_int, "w") as f:
383
+ disam_sym_id = len(isymbols)
384
+ for _ in range(num_disambig):
385
+ f.write("{}\n".format(disam_sym_id))
386
+ disam_sym_id += 1
387
+
388
+ fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
389
+ fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
390
+ fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"
391
+
392
+ try:
393
+ with open(h_graph, "wb") as out_f:
394
+ res = subprocess.run(
395
+ [
396
+ fstcompile,
397
+ f"--isymbols={isym_file}",
398
+ f"--osymbols={h_out_units_file}",
399
+ "--keep_isymbols=false",
400
+ "--keep_osymbols=false",
401
+ ],
402
+ input=str.encode(fst_str),
403
+ capture_output=True,
404
+ check=True,
405
+ )
406
+ res = subprocess.run(
407
+ [
408
+ fstaddselfloops,
409
+ disambig_in_units_file_int,
410
+ disambig_out_units_file_int,
411
+ ],
412
+ input=res.stdout,
413
+ capture_output=True,
414
+ check=True,
415
+ )
416
+ res = subprocess.run(
417
+ [fstarcsort, "--sort_type=olabel"],
418
+ input=res.stdout,
419
+ capture_output=True,
420
+ check=True,
421
+ )
422
+ out_f.write(res.stdout)
423
+ except subprocess.CalledProcessError as e:
424
+ logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
425
+ os.remove(h_graph)
426
+ raise
427
+ return h_graph, h_out_units_file, disambig_in_units_file_int
428
+
429
+
430
+ def create_HLGa(
431
+ kaldi_root: Path,
432
+ fst_dir: Path,
433
+ unique_label: str,
434
+ h_graph: Path,
435
+ lg_graph: Path,
436
+ disambig_in_words_file_int: Path,
437
+ ) -> Path:
438
+ hlga_graph = fst_dir / f"HLGa.{unique_label}.fst"
439
+
440
+ if not hlga_graph.exists():
441
+ logger.info(f"Creating {hlga_graph}")
442
+
443
+ fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
444
+ fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
445
+ fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
446
+ fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
447
+ fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
448
+
449
+ try:
450
+ with open(hlga_graph, "wb") as out_f:
451
+ res = subprocess.run(
452
+ [
453
+ fsttablecompose,
454
+ h_graph,
455
+ lg_graph,
456
+ ],
457
+ capture_output=True,
458
+ check=True,
459
+ )
460
+ res = subprocess.run(
461
+ [fstdeterminizestar, "--use-log=true"],
462
+ input=res.stdout,
463
+ capture_output=True,
464
+ check=True,
465
+ )
466
+ res = subprocess.run(
467
+ [fstrmsymbols, disambig_in_words_file_int],
468
+ input=res.stdout,
469
+ capture_output=True,
470
+ check=True,
471
+ )
472
+ res = subprocess.run(
473
+ [fstrmepslocal],
474
+ input=res.stdout,
475
+ capture_output=True,
476
+ check=True,
477
+ )
478
+ res = subprocess.run(
479
+ [fstminimizeencoded],
480
+ input=res.stdout,
481
+ capture_output=True,
482
+ check=True,
483
+ )
484
+ out_f.write(res.stdout)
485
+ except subprocess.CalledProcessError as e:
486
+ logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
487
+ os.remove(hlga_graph)
488
+ raise
489
+
490
+ return hlga_graph
491
+
492
+
493
+ def create_HLa(
494
+ kaldi_root: Path,
495
+ fst_dir: Path,
496
+ unique_label: str,
497
+ h_graph: Path,
498
+ l_graph: Path,
499
+ disambig_in_words_file_int: Path,
500
+ ) -> Path:
501
+ hla_graph = fst_dir / f"HLa.{unique_label}.fst"
502
+
503
+ if not hla_graph.exists():
504
+ logger.info(f"Creating {hla_graph}")
505
+
506
+ fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
507
+ fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
508
+ fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
509
+ fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
510
+ fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
511
+
512
+ try:
513
+ with open(hla_graph, "wb") as out_f:
514
+ res = subprocess.run(
515
+ [
516
+ fsttablecompose,
517
+ h_graph,
518
+ l_graph,
519
+ ],
520
+ capture_output=True,
521
+ check=True,
522
+ )
523
+ res = subprocess.run(
524
+ [fstdeterminizestar, "--use-log=true"],
525
+ input=res.stdout,
526
+ capture_output=True,
527
+ check=True,
528
+ )
529
+ res = subprocess.run(
530
+ [fstrmsymbols, disambig_in_words_file_int],
531
+ input=res.stdout,
532
+ capture_output=True,
533
+ check=True,
534
+ )
535
+ res = subprocess.run(
536
+ [fstrmepslocal],
537
+ input=res.stdout,
538
+ capture_output=True,
539
+ check=True,
540
+ )
541
+ res = subprocess.run(
542
+ [fstminimizeencoded],
543
+ input=res.stdout,
544
+ capture_output=True,
545
+ check=True,
546
+ )
547
+ out_f.write(res.stdout)
548
+ except subprocess.CalledProcessError as e:
549
+ logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
550
+ os.remove(hla_graph)
551
+ raise
552
+
553
+ return hla_graph
554
+
555
+
556
+ def create_HLG(
557
+ kaldi_root: Path,
558
+ fst_dir: Path,
559
+ unique_label: str,
560
+ hlga_graph: Path,
561
+ prefix: str = "HLG",
562
+ ) -> Path:
563
+ hlg_graph = fst_dir / f"{prefix}.{unique_label}.fst"
564
+
565
+ if not hlg_graph.exists():
566
+ logger.info(f"Creating {hlg_graph}")
567
+
568
+ add_self_loop = script_dir / "add-self-loop-simple"
569
+ kaldi_src = kaldi_root / "src"
570
+ kaldi_lib = kaldi_src / "lib"
571
+
572
+ try:
573
+ if not add_self_loop.exists():
574
+ fst_include = kaldi_root / "tools/openfst-1.6.7/include"
575
+ add_self_loop_src = script_dir / "add-self-loop-simple.cc"
576
+
577
+ subprocess.run(
578
+ [
579
+ "c++",
580
+ f"-I{kaldi_src}",
581
+ f"-I{fst_include}",
582
+ f"-L{kaldi_lib}",
583
+ add_self_loop_src,
584
+ "-lkaldi-base",
585
+ "-lkaldi-fstext",
586
+ "-o",
587
+ add_self_loop,
588
+ ],
589
+ check=True,
590
+ )
591
+
592
+ my_env = os.environ.copy()
593
+ my_env["LD_LIBRARY_PATH"] = f"{kaldi_lib}:{my_env['LD_LIBRARY_PATH']}"
594
+
595
+ subprocess.run(
596
+ [
597
+ add_self_loop,
598
+ hlga_graph,
599
+ hlg_graph,
600
+ ],
601
+ check=True,
602
+ capture_output=True,
603
+ env=my_env,
604
+ )
605
+ except subprocess.CalledProcessError as e:
606
+ logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
607
+ raise
608
+
609
+ return hlg_graph
610
+
611
+
612
+ def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path:
613
+ if cfg.fst_dir is None:
614
+ cfg.fst_dir = osp.join(cfg.data_dir, "kaldi")
615
+ if cfg.out_labels is None:
616
+ cfg.out_labels = cfg.in_labels
617
+
618
+ kaldi_root = Path(cfg.kaldi_root)
619
+ data_dir = Path(cfg.data_dir)
620
+ fst_dir = Path(cfg.fst_dir)
621
+ fst_dir.mkdir(parents=True, exist_ok=True)
622
+
623
+ arpa_base = osp.splitext(osp.basename(cfg.lm_arpa))[0]
624
+ unique_label = f"{cfg.in_labels}.{arpa_base}"
625
+
626
+ with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as f:
627
+ vocab = Dictionary.load(f)
628
+
629
+ in_units_file = create_units(fst_dir, cfg.in_labels, vocab)
630
+
631
+ grammar_graph, out_words_file = create_G(
632
+ kaldi_root, fst_dir, Path(cfg.lm_arpa), arpa_base
633
+ )
634
+
635
+ disambig_lexicon_file, disambig_L_in_units_file = create_lexicon(
636
+ cfg, fst_dir, unique_label, in_units_file, out_words_file
637
+ )
638
+
639
+ h_graph, h_out_units_file, disambig_in_units_file_int = create_H(
640
+ kaldi_root,
641
+ fst_dir,
642
+ disambig_L_in_units_file,
643
+ cfg.in_labels,
644
+ vocab,
645
+ cfg.blank_symbol,
646
+ cfg.silence_symbol,
647
+ )
648
+ lexicon_graph = create_L(
649
+ kaldi_root,
650
+ fst_dir,
651
+ unique_label,
652
+ disambig_lexicon_file,
653
+ disambig_L_in_units_file,
654
+ out_words_file,
655
+ )
656
+ lg_graph = create_LG(
657
+ kaldi_root, fst_dir, unique_label, lexicon_graph, grammar_graph
658
+ )
659
+ hlga_graph = create_HLGa(
660
+ kaldi_root, fst_dir, unique_label, h_graph, lg_graph, disambig_in_units_file_int
661
+ )
662
+ hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph)
663
+
664
+ # for debugging
665
+ # hla_graph = create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int)
666
+ # hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped")
667
+ # create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped")
668
+
669
+ return hlg_graph
670
+
671
+
672
+ @hydra.main(config_path=config_path, config_name="kaldi_initializer")
673
+ def cli_main(cfg: KaldiInitializerConfig) -> None:
674
+ container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
675
+ cfg = OmegaConf.create(container)
676
+ OmegaConf.set_struct(cfg, True)
677
+ initalize_kaldi(cfg)
678
+
679
+
680
+ if __name__ == "__main__":
681
+
682
+ logging.root.setLevel(logging.INFO)
683
+ logging.basicConfig(level=logging.INFO)
684
+
685
+ try:
686
+ from hydra._internal.utils import (
687
+ get_args,
688
+ ) # pylint: disable=import-outside-toplevel
689
+
690
+ cfg_name = get_args().config_name or "kaldi_initializer"
691
+ except ImportError:
692
+ logger.warning("Failed to get config name from hydra args")
693
+ cfg_name = "kaldi_initializer"
694
+
695
+ cs = ConfigStore.instance()
696
+ cs.store(name=cfg_name, node=KaldiInitializerConfig)
697
+
698
+ cli_main()
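Besides the Hydra entry point above, the graph construction can also be driven directly from Python; a minimal sketch with placeholder paths (data_dir must already contain the corresponding fairseq dictionary, e.g. dict.ltr.txt, and kaldi_root must point at a compiled Kaldi checkout):

    from examples.speech_recognition.kaldi.kaldi_initializer import (
        KaldiInitializerConfig,
        initalize_kaldi,
    )

    cfg = KaldiInitializerConfig(
        data_dir="/path/to/data",       # placeholder; expects dict.ltr.txt here
        in_labels="ltr",
        lm_arpa="/path/to/4gram.arpa",  # placeholder ARPA language model
        kaldi_root="/path/to/kaldi",    # placeholder Kaldi checkout
    )
    hlg_path = initalize_kaldi(cfg)     # builds units, G, L, H and returns the HLG fst path
    print(hlg_path)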
fairseq/examples/speech_recognition/models/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ import importlib
2
+ import os
3
+
4
+
5
+ for file in sorted(os.listdir(os.path.dirname(__file__))):
6
+ if file.endswith(".py") and not file.startswith("_"):
7
+ model_name = file[: file.find(".py")]
8
+ importlib.import_module("examples.speech_recognition.models." + model_name)
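Importing this package is what triggers the loop above and registers the example models with fairseq; a small sketch, assuming the fairseq repository root is on PYTHONPATH:

    import examples.speech_recognition.models  # runs the auto-import loop
    from fairseq.models import MODEL_REGISTRY

    assert "asr_vggtransformer" in MODEL_REGISTRY  # registered in vggtransformer.py below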
fairseq/examples/speech_recognition/models/vggtransformer.py ADDED
@@ -0,0 +1,1020 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import math
8
+ from collections.abc import Iterable
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
13
+ from fairseq import utils
14
+ from fairseq.models import (
15
+ FairseqEncoder,
16
+ FairseqEncoderDecoderModel,
17
+ FairseqEncoderModel,
18
+ FairseqIncrementalDecoder,
19
+ register_model,
20
+ register_model_architecture,
21
+ )
22
+ from fairseq.modules import (
23
+ LinearizedConvolution,
24
+ TransformerDecoderLayer,
25
+ TransformerEncoderLayer,
26
+ VGGBlock,
27
+ )
28
+
29
+
30
+ @register_model("asr_vggtransformer")
31
+ class VGGTransformerModel(FairseqEncoderDecoderModel):
32
+ """
33
+ Transformers with convolutional context for ASR
34
+ https://arxiv.org/abs/1904.11660
35
+ """
36
+
37
+ def __init__(self, encoder, decoder):
38
+ super().__init__(encoder, decoder)
39
+
40
+ @staticmethod
41
+ def add_args(parser):
42
+ """Add model-specific arguments to the parser."""
43
+ parser.add_argument(
44
+ "--input-feat-per-channel",
45
+ type=int,
46
+ metavar="N",
47
+ help="encoder input dimension per input channel",
48
+ )
49
+ parser.add_argument(
50
+ "--vggblock-enc-config",
51
+ type=str,
52
+ metavar="EXPR",
53
+ help="""
54
+ an array of tuples each containing the configuration of one vggblock:
55
+ [(out_channels,
56
+ conv_kernel_size,
57
+ pooling_kernel_size,
58
+ num_conv_layers,
59
+ use_layer_norm), ...])
60
+ """,
61
+ )
62
+ parser.add_argument(
63
+ "--transformer-enc-config",
64
+ type=str,
65
+ metavar="EXPR",
66
+ help="""
67
+ a tuple containing the configuration of the encoder transformer layers
68
+ configurations:
69
+ [(input_dim,
70
+ num_heads,
71
+ ffn_dim,
72
+ normalize_before,
73
+ dropout,
74
+ attention_dropout,
75
+ relu_dropout), ...]')
76
+ """,
77
+ )
78
+ parser.add_argument(
79
+ "--enc-output-dim",
80
+ type=int,
81
+ metavar="N",
82
+ help="""
83
+ encoder output dimension, can be None. If specified, projecting the
84
+ transformer output to the specified dimension""",
85
+ )
86
+ parser.add_argument(
87
+ "--in-channels",
88
+ type=int,
89
+ metavar="N",
90
+ help="number of encoder input channels",
91
+ )
92
+ parser.add_argument(
93
+ "--tgt-embed-dim",
94
+ type=int,
95
+ metavar="N",
96
+ help="embedding dimension of the decoder target tokens",
97
+ )
98
+ parser.add_argument(
99
+ "--transformer-dec-config",
100
+ type=str,
101
+ metavar="EXPR",
102
+ help="""
103
+ a tuple containing the configuration of the decoder transformer layers
104
+ configurations:
105
+ [(input_dim,
106
+ num_heads,
107
+ ffn_dim,
108
+ normalize_before,
109
+ dropout,
110
+ attention_dropout,
111
+ relu_dropout), ...]
112
+ """,
113
+ )
114
+ parser.add_argument(
115
+ "--conv-dec-config",
116
+ type=str,
117
+ metavar="EXPR",
118
+ help="""
119
+ an array of tuples for the decoder 1-D convolution config
120
+ [(out_channels, conv_kernel_size, use_layer_norm), ...]""",
121
+ )
122
+
123
+ @classmethod
124
+ def build_encoder(cls, args, task):
125
+ return VGGTransformerEncoder(
126
+ input_feat_per_channel=args.input_feat_per_channel,
127
+ vggblock_config=eval(args.vggblock_enc_config),
128
+ transformer_config=eval(args.transformer_enc_config),
129
+ encoder_output_dim=args.enc_output_dim,
130
+ in_channels=args.in_channels,
131
+ )
132
+
133
+ @classmethod
134
+ def build_decoder(cls, args, task):
135
+ return TransformerDecoder(
136
+ dictionary=task.target_dictionary,
137
+ embed_dim=args.tgt_embed_dim,
138
+ transformer_config=eval(args.transformer_dec_config),
139
+ conv_config=eval(args.conv_dec_config),
140
+ encoder_output_dim=args.enc_output_dim,
141
+ )
142
+
143
+ @classmethod
144
+ def build_model(cls, args, task):
145
+ """Build a new model instance."""
146
+ # make sure that all args are properly defaulted
147
+ # (in case there are any new ones)
148
+ base_architecture(args)
149
+
150
+ encoder = cls.build_encoder(args, task)
151
+ decoder = cls.build_decoder(args, task)
152
+ return cls(encoder, decoder)
153
+
154
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
155
+ # net_output['encoder_out'] is a (B, T, D) tensor
156
+ lprobs = super().get_normalized_probs(net_output, log_probs, sample)
157
+ lprobs.batch_first = True
158
+ return lprobs
159
+
160
+
161
+ DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2
162
+ DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2
163
+ # 256: embedding dimension
164
+ # 4: number of heads
165
+ # 1024: FFN
166
+ # True: apply layerNorm before (dropout + residual) instead of after
167
+ # 0.2 (dropout): dropout after MultiheadAttention and second FC
168
+ # 0.2 (attention_dropout): dropout in MultiheadAttention
169
+ # 0.2 (relu_dropout): dropout after ReLU
170
+ DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2
171
+ DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2
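# Illustrative example (hypothetical flag values): these one-liner configs are
# passed on the command line as Python literals and parsed with eval() in
# build_encoder / build_decoder above, e.g.
#   --vggblock-enc-config "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
#   --transformer-enc-config "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 2"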
172
+
173
+
174
+ # TODO: replace transformer encoder config from a one-liner
175
+ # to explicit args to get rid of this transformation
176
+ def prepare_transformer_encoder_params(
177
+ input_dim,
178
+ num_heads,
179
+ ffn_dim,
180
+ normalize_before,
181
+ dropout,
182
+ attention_dropout,
183
+ relu_dropout,
184
+ ):
185
+ args = argparse.Namespace()
186
+ args.encoder_embed_dim = input_dim
187
+ args.encoder_attention_heads = num_heads
188
+ args.attention_dropout = attention_dropout
189
+ args.dropout = dropout
190
+ args.activation_dropout = relu_dropout
191
+ args.encoder_normalize_before = normalize_before
192
+ args.encoder_ffn_embed_dim = ffn_dim
193
+ return args
194
+
195
+
196
+ def prepare_transformer_decoder_params(
197
+ input_dim,
198
+ num_heads,
199
+ ffn_dim,
200
+ normalize_before,
201
+ dropout,
202
+ attention_dropout,
203
+ relu_dropout,
204
+ ):
205
+ args = argparse.Namespace()
206
+ args.encoder_embed_dim = None
207
+ args.decoder_embed_dim = input_dim
208
+ args.decoder_attention_heads = num_heads
209
+ args.attention_dropout = attention_dropout
210
+ args.dropout = dropout
211
+ args.activation_dropout = relu_dropout
212
+ args.decoder_normalize_before = normalize_before
213
+ args.decoder_ffn_embed_dim = ffn_dim
214
+ return args
215
+
216
+
217
+ class VGGTransformerEncoder(FairseqEncoder):
218
+ """VGG + Transformer encoder"""
219
+
220
+ def __init__(
221
+ self,
222
+ input_feat_per_channel,
223
+ vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
224
+ transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
225
+ encoder_output_dim=512,
226
+ in_channels=1,
227
+ transformer_context=None,
228
+ transformer_sampling=None,
229
+ ):
230
+ """constructor for VGGTransformerEncoder
231
+
232
+ Args:
233
+ - input_feat_per_channel: feature dim (not including stacked,
234
+ just base feature)
235
+ - in_channel: # input channels (e.g., if you stack 8 feature vectors
236
+ together, this is 8)
237
+ - vggblock_config: configuration of vggblock, see comments on
238
+ DEFAULT_ENC_VGGBLOCK_CONFIG
239
+ - transformer_config: configuration of transformer layer, see comments
240
+ on DEFAULT_ENC_TRANSFORMER_CONFIG
241
+ - encoder_output_dim: final transformer output embedding dimension
242
+ - transformer_context: (left, right) if set, self-attention will be focused
243
+ on (t-left, t+right)
244
+ - transformer_sampling: an iterable of int, must match with
245
+ len(transformer_config), transformer_sampling[i] indicates sampling
246
+ factor for i-th transformer layer, after multihead att and feedforward
247
+ part
248
+ """
249
+ super().__init__(None)
250
+
251
+ self.num_vggblocks = 0
252
+ if vggblock_config is not None:
253
+ if not isinstance(vggblock_config, Iterable):
254
+ raise ValueError("vggblock_config is not iterable")
255
+ self.num_vggblocks = len(vggblock_config)
256
+
257
+ self.conv_layers = nn.ModuleList()
258
+ self.in_channels = in_channels
259
+ self.input_dim = input_feat_per_channel
260
+ self.pooling_kernel_sizes = []
261
+
262
+ if vggblock_config is not None:
263
+ for _, config in enumerate(vggblock_config):
264
+ (
265
+ out_channels,
266
+ conv_kernel_size,
267
+ pooling_kernel_size,
268
+ num_conv_layers,
269
+ layer_norm,
270
+ ) = config
271
+ self.conv_layers.append(
272
+ VGGBlock(
273
+ in_channels,
274
+ out_channels,
275
+ conv_kernel_size,
276
+ pooling_kernel_size,
277
+ num_conv_layers,
278
+ input_dim=input_feat_per_channel,
279
+ layer_norm=layer_norm,
280
+ )
281
+ )
282
+ self.pooling_kernel_sizes.append(pooling_kernel_size)
283
+ in_channels = out_channels
284
+ input_feat_per_channel = self.conv_layers[-1].output_dim
285
+
286
+ transformer_input_dim = self.infer_conv_output_dim(
287
+ self.in_channels, self.input_dim
288
+ )
289
+ # transformer_input_dim is the output dimension of VGG part
290
+
291
+ self.validate_transformer_config(transformer_config)
292
+ self.transformer_context = self.parse_transformer_context(transformer_context)
293
+ self.transformer_sampling = self.parse_transformer_sampling(
294
+ transformer_sampling, len(transformer_config)
295
+ )
296
+
297
+ self.transformer_layers = nn.ModuleList()
298
+
299
+ if transformer_input_dim != transformer_config[0][0]:
300
+ self.transformer_layers.append(
301
+ Linear(transformer_input_dim, transformer_config[0][0])
302
+ )
303
+ self.transformer_layers.append(
304
+ TransformerEncoderLayer(
305
+ prepare_transformer_encoder_params(*transformer_config[0])
306
+ )
307
+ )
308
+
309
+ for i in range(1, len(transformer_config)):
310
+ if transformer_config[i - 1][0] != transformer_config[i][0]:
311
+ self.transformer_layers.append(
312
+ Linear(transformer_config[i - 1][0], transformer_config[i][0])
313
+ )
314
+ self.transformer_layers.append(
315
+ TransformerEncoderLayer(
316
+ prepare_transformer_encoder_params(*transformer_config[i])
317
+ )
318
+ )
319
+
320
+ self.encoder_output_dim = encoder_output_dim
321
+ self.transformer_layers.extend(
322
+ [
323
+ Linear(transformer_config[-1][0], encoder_output_dim),
324
+ LayerNorm(encoder_output_dim),
325
+ ]
326
+ )
327
+
328
+ def forward(self, src_tokens, src_lengths, **kwargs):
329
+ """
330
+ src_tokens: padded tensor (B, T, C * feat)
331
+ src_lengths: tensor of original lengths of input utterances (B,)
332
+ """
333
+ bsz, max_seq_len, _ = src_tokens.size()
334
+ x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
335
+ x = x.transpose(1, 2).contiguous()
336
+ # (B, C, T, feat)
337
+
338
+ for layer_idx in range(len(self.conv_layers)):
339
+ x = self.conv_layers[layer_idx](x)
340
+
341
+ bsz, _, output_seq_len, _ = x.size()
342
+
343
+ # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
344
+ x = x.transpose(1, 2).transpose(0, 1)
345
+ x = x.contiguous().view(output_seq_len, bsz, -1)
346
+
347
+ input_lengths = src_lengths.clone()
348
+ for s in self.pooling_kernel_sizes:
349
+ input_lengths = (input_lengths.float() / s).ceil().long()
350
+
351
+ encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
352
+ input_lengths, batch_first=True
353
+ )
354
+ if not encoder_padding_mask.any():
355
+ encoder_padding_mask = None
356
+
357
+ subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
358
+ attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)
359
+
360
+ transformer_layer_idx = 0
361
+
362
+ for layer_idx in range(len(self.transformer_layers)):
363
+
364
+ if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer):
365
+ x = self.transformer_layers[layer_idx](
366
+ x, encoder_padding_mask, attn_mask
367
+ )
368
+
369
+ if self.transformer_sampling[transformer_layer_idx] != 1:
370
+ sampling_factor = self.transformer_sampling[transformer_layer_idx]
371
+ x, encoder_padding_mask, attn_mask = self.slice(
372
+ x, encoder_padding_mask, attn_mask, sampling_factor
373
+ )
374
+
375
+ transformer_layer_idx += 1
376
+
377
+ else:
378
+ x = self.transformer_layers[layer_idx](x)
379
+
380
+ # encoder_padding_mask is a (T x B) tensor, its [t, b] elements indicate
381
+ # whether encoder_output[t, b] is valid or not (valid=0, invalid=1)
382
+
383
+ return {
384
+ "encoder_out": x, # (T, B, C)
385
+ "encoder_padding_mask": encoder_padding_mask.t()
386
+ if encoder_padding_mask is not None
387
+ else None,
388
+ # (B, T) --> (T, B)
389
+ }
390
+
391
+ def infer_conv_output_dim(self, in_channels, input_dim):
392
+ sample_seq_len = 200
393
+ sample_bsz = 10
394
+ x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
395
+ for i, _ in enumerate(self.conv_layers):
396
+ x = self.conv_layers[i](x)
397
+ x = x.transpose(1, 2)
398
+ mb, seq = x.size()[:2]
399
+ return x.contiguous().view(mb, seq, -1).size(-1)
400
+
401
+ def validate_transformer_config(self, transformer_config):
402
+ for config in transformer_config:
403
+ input_dim, num_heads = config[:2]
404
+ if input_dim % num_heads != 0:
405
+ msg = (
406
+ "ERROR in transformer config {}: ".format(config)
407
+ + "input dimension {} ".format(input_dim)
408
+ + "not divisible by number of heads {}".format(num_heads)
409
+ )
410
+ raise ValueError(msg)
411
+
412
+ def parse_transformer_context(self, transformer_context):
413
+ """
414
+ transformer_context can be the following:
415
+ - None; indicates no context is used, i.e.,
416
+ transformer can access full context
417
+ - a tuple/list of two int; indicates left and right context,
418
+ any number <0 indicates infinite context
419
+ * e.g., (5, 6) indicates that for query at x_t, transformer can
420
+ access [t-5, t+6] (inclusive)
421
+ * e.g., (-1, 6) indicates that for query at x_t, transformer can
422
+ access [0, t+6] (inclusive)
423
+ """
424
+ if transformer_context is None:
425
+ return None
426
+
427
+ if not isinstance(transformer_context, Iterable):
428
+ raise ValueError("transformer context must be Iterable if it is not None")
429
+
430
+ if len(transformer_context) != 2:
431
+ raise ValueError("transformer context must have length 2")
432
+
433
+ left_context = transformer_context[0]
434
+ if left_context < 0:
435
+ left_context = None
436
+
437
+ right_context = transformer_context[1]
438
+ if right_context < 0:
439
+ right_context = None
440
+
441
+ if left_context is None and right_context is None:
442
+ return None
443
+
444
+ return (left_context, right_context)
445
+
446
+ def parse_transformer_sampling(self, transformer_sampling, num_layers):
447
+ """
448
+ parsing transformer sampling configuration
449
+
450
+ Args:
451
+ - transformer_sampling, accepted input:
452
+ * None, indicating no sampling
453
+ * an Iterable with int (>0) as element
454
+ - num_layers, expected number of transformer layers, must match with
455
+ the length of transformer_sampling if it is not None
456
+
457
+ Returns:
458
+ - A tuple with length num_layers
459
+ """
460
+ if transformer_sampling is None:
461
+ return (1,) * num_layers
462
+
463
+ if not isinstance(transformer_sampling, Iterable):
464
+ raise ValueError(
465
+ "transformer_sampling must be an iterable if it is not None"
466
+ )
467
+
468
+ if len(transformer_sampling) != num_layers:
469
+ raise ValueError(
470
+ "transformer_sampling {} does not match with the number "
471
+ "of layers {}".format(transformer_sampling, num_layers)
472
+ )
473
+
474
+ for layer, value in enumerate(transformer_sampling):
475
+ if not isinstance(value, int):
476
+ raise ValueError("Invalid value in transformer_sampling: ")
477
+ if value < 1:
478
+ raise ValueError(
479
+ "{} layer's subsampling is {}.".format(layer, value)
480
+ + " This is not allowed! "
481
+ )
482
+ return transformer_sampling
483
+
484
+ def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
485
+ """
486
+ embedding is a (T, B, D) tensor
487
+ padding_mask is a (B, T) tensor or None
488
+ attn_mask is a (T, T) tensor or None
489
+ """
490
+ embedding = embedding[::sampling_factor, :, :]
491
+ if padding_mask is not None:
492
+ padding_mask = padding_mask[:, ::sampling_factor]
493
+ if attn_mask is not None:
494
+ attn_mask = attn_mask[::sampling_factor, ::sampling_factor]
495
+
496
+ return embedding, padding_mask, attn_mask
497
+
498
+ def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
499
+ """
500
+ create attention mask according to sequence lengths and transformer
501
+ context
502
+
503
+ Args:
504
+ - input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
505
+ the length of b-th sequence
506
+ - subsampling_factor: int
507
+ * Note that left_context and right_context are specified at the
+ input frame level, while the transformer input may already have
+ gone through subsampling (e.g., striding in vggblock); we use
+ subsampling_factor to scale the left/right context accordingly
511
+
512
+ Return:
513
+ - a (T, T) binary tensor or None, where T is max(input_lengths)
514
+ * if self.transformer_context is None, None
515
+ * if left_context is None,
516
+ * attn_mask[t, t + right_context + 1:] = 1
517
+ * others = 0
518
+ * if right_context is None,
519
+ * attn_mask[t, 0:t - left_context] = 1
520
+ * others = 0
521
+ * otherwise (both contexts given),
522
+ * attn_mask[t, t - left_context: t + right_context + 1] = 0
523
+ * others = 1
524
+ """
525
+ if self.transformer_context is None:
526
+ return None
527
+
528
+ maxT = torch.max(input_lengths).item()
529
+ attn_mask = torch.zeros(maxT, maxT)
530
+
531
+ left_context = self.transformer_context[0]
532
+ right_context = self.transformer_context[1]
533
+ if left_context is not None:
534
+ left_context = math.ceil(self.transformer_context[0] / subsampling_factor)
535
+ if right_context is not None:
536
+ right_context = math.ceil(self.transformer_context[1] / subsampling_factor)
537
+
538
+ for t in range(maxT):
539
+ if left_context is not None:
540
+ st = 0
541
+ en = max(st, t - left_context)
542
+ attn_mask[t, st:en] = 1
543
+ if right_context is not None:
544
+ st = t + right_context + 1
545
+ st = min(st, maxT - 1)
546
+ attn_mask[t, st:] = 1
547
+
548
+ return attn_mask.to(input_lengths.device)
549
+
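+ # Minimal worked example (illustrative only): with input_lengths = [4],
+ # subsampling_factor = 1 and self.transformer_context = (1, None)
+ # (e.g. parsed from "(1, -1)"), the returned mask is
+ # [[0, 0, 0, 0],
+ #  [0, 0, 0, 0],
+ #  [1, 0, 0, 0],
+ #  [1, 1, 0, 0]]
+ # where a 1 marks a position the query at row t may NOT attend to.
+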
550
+ def reorder_encoder_out(self, encoder_out, new_order):
551
+ encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
552
+ 1, new_order
553
+ )
554
+ if encoder_out["encoder_padding_mask"] is not None:
555
+ encoder_out["encoder_padding_mask"] = encoder_out[
556
+ "encoder_padding_mask"
557
+ ].index_select(1, new_order)
558
+ return encoder_out
559
+
560
+
561
+ class TransformerDecoder(FairseqIncrementalDecoder):
562
+ """
563
+ Transformer decoder consisting of one :class:`TransformerDecoderLayer` per
+ entry in *transformer_config*, preceded by a stack of linearized convolutions.
+ Args:
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
+ embed_dim (int, optional): embedding dimension of target tokens. Default: 512
+ transformer_config (tuple, optional): one (input_dim, num_heads, ffn_dim,
+ normalize_before, dropout, attention_dropout, relu_dropout) tuple
+ per decoder transformer layer
+ conv_config (tuple, optional): one (out_channels, kernel_size, layer_norm)
+ tuple per convolutional layer
+ encoder_output_dim (int, optional): dimension of the encoder output. Default: 512
573
+ """
574
+
575
+ def __init__(
576
+ self,
577
+ dictionary,
578
+ embed_dim=512,
579
+ transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
580
+ conv_config=DEFAULT_DEC_CONV_CONFIG,
581
+ encoder_output_dim=512,
582
+ ):
583
+
584
+ super().__init__(dictionary)
585
+ vocab_size = len(dictionary)
586
+ self.padding_idx = dictionary.pad()
587
+ self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)
588
+
589
+ self.conv_layers = nn.ModuleList()
590
+ for i in range(len(conv_config)):
591
+ out_channels, kernel_size, layer_norm = conv_config[i]
592
+ if i == 0:
593
+ conv_layer = LinearizedConv1d(
594
+ embed_dim, out_channels, kernel_size, padding=kernel_size - 1
595
+ )
596
+ else:
597
+ conv_layer = LinearizedConv1d(
598
+ conv_config[i - 1][0],
599
+ out_channels,
600
+ kernel_size,
601
+ padding=kernel_size - 1,
602
+ )
603
+ self.conv_layers.append(conv_layer)
604
+ if layer_norm:
605
+ self.conv_layers.append(nn.LayerNorm(out_channels))
606
+ self.conv_layers.append(nn.ReLU())
607
+
608
+ self.layers = nn.ModuleList()
609
+ if conv_config[-1][0] != transformer_config[0][0]:
610
+ self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
611
+ self.layers.append(
612
+ TransformerDecoderLayer(
613
+ prepare_transformer_decoder_params(*transformer_config[0])
614
+ )
615
+ )
616
+
617
+ for i in range(1, len(transformer_config)):
618
+ if transformer_config[i - 1][0] != transformer_config[i][0]:
619
+ self.layers.append(
620
+ Linear(transformer_config[i - 1][0], transformer_config[i][0])
621
+ )
622
+ self.layers.append(
623
+ TransformerDecoderLayer(
624
+ prepare_transformer_decoder_params(*transformer_config[i])
625
+ )
626
+ )
627
+ self.fc_out = Linear(transformer_config[-1][0], vocab_size)
628
+
629
+ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
630
+ """
631
+ Args:
632
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
633
+ `(batch, tgt_len)`, for input feeding/teacher forcing
634
+ encoder_out (Tensor, optional): output from the encoder, used for
635
+ encoder-side attention
636
+ incremental_state (dict): dictionary used for storing state during
637
+ :ref:`Incremental decoding`
638
+ Returns:
639
+ tuple:
640
+ - the last decoder layer's output of shape `(batch, tgt_len,
641
+ vocab)`
642
+ - the last decoder layer's attention weights of shape `(batch,
643
+ tgt_len, src_len)`
644
+ """
645
+ target_padding_mask = (
646
+ (prev_output_tokens == self.padding_idx).to(prev_output_tokens.device)
647
+ if incremental_state is None
648
+ else None
649
+ )
650
+
651
+ if incremental_state is not None:
652
+ prev_output_tokens = prev_output_tokens[:, -1:]
653
+
654
+ # embed tokens
655
+ x = self.embed_tokens(prev_output_tokens)
656
+
657
+ # B x T x C -> T x B x C
658
+ x = self._transpose_if_training(x, incremental_state)
659
+
660
+ for layer in self.conv_layers:
661
+ if isinstance(layer, LinearizedConvolution):
662
+ x = layer(x, incremental_state)
663
+ else:
664
+ x = layer(x)
665
+
666
+ # B x T x C -> T x B x C
667
+ x = self._transpose_if_inference(x, incremental_state)
668
+
669
+ # decoder layers
670
+ for layer in self.layers:
671
+ if isinstance(layer, TransformerDecoderLayer):
672
+ x, *_ = layer(
673
+ x,
674
+ (encoder_out["encoder_out"] if encoder_out is not None else None),
675
+ (
676
+ encoder_out["encoder_padding_mask"].t()
677
+ if encoder_out["encoder_padding_mask"] is not None
678
+ else None
679
+ ),
680
+ incremental_state,
681
+ self_attn_mask=(
682
+ self.buffered_future_mask(x)
683
+ if incremental_state is None
684
+ else None
685
+ ),
686
+ self_attn_padding_mask=(
687
+ target_padding_mask if incremental_state is None else None
688
+ ),
689
+ )
690
+ else:
691
+ x = layer(x)
692
+
693
+ # T x B x C -> B x T x C
694
+ x = x.transpose(0, 1)
695
+
696
+ x = self.fc_out(x)
697
+
698
+ return x, None
699
+
700
+ def buffered_future_mask(self, tensor):
701
+ dim = tensor.size(0)
702
+ if (
703
+ not hasattr(self, "_future_mask")
704
+ or self._future_mask is None
705
+ or self._future_mask.device != tensor.device
706
+ ):
707
+ self._future_mask = torch.triu(
708
+ utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
709
+ )
710
+ if self._future_mask.size(0) < dim:
711
+ self._future_mask = torch.triu(
712
+ utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
713
+ )
714
+ return self._future_mask[:dim, :dim]
715
+
716
+ def _transpose_if_training(self, x, incremental_state):
717
+ if incremental_state is None:
718
+ x = x.transpose(0, 1)
719
+ return x
720
+
721
+ def _transpose_if_inference(self, x, incremental_state):
722
+ if incremental_state:
723
+ x = x.transpose(0, 1)
724
+ return x
725
+
726
+
727
+ @register_model("asr_vggtransformer_encoder")
728
+ class VGGTransformerEncoderModel(FairseqEncoderModel):
729
+ def __init__(self, encoder):
730
+ super().__init__(encoder)
731
+
732
+ @staticmethod
733
+ def add_args(parser):
734
+ """Add model-specific arguments to the parser."""
735
+ parser.add_argument(
736
+ "--input-feat-per-channel",
737
+ type=int,
738
+ metavar="N",
739
+ help="encoder input dimension per input channel",
740
+ )
741
+ parser.add_argument(
742
+ "--vggblock-enc-config",
743
+ type=str,
744
+ metavar="EXPR",
745
+ help="""
746
+ an array of tuples each containing the configuration of one vggblock
747
+ [(out_channels, conv_kernel_size, pooling_kernel_size, num_conv_layers, layer_norm), ...]
748
+ """,
749
+ )
750
+ parser.add_argument(
751
+ "--transformer-enc-config",
752
+ type=str,
753
+ metavar="EXPR",
754
+ help="""
755
+ a tuple containing the configuration of the Transformer layers
756
+ configurations:
757
+ [(input_dim,
758
+ num_heads,
759
+ ffn_dim,
760
+ normalize_before,
761
+ dropout,
762
+ attention_dropout,
763
+ relu_dropout), ]""",
764
+ )
765
+ parser.add_argument(
766
+ "--enc-output-dim",
767
+ type=int,
768
+ metavar="N",
769
+ help="encoder output dimension, projecting the LSTM output",
770
+ )
771
+ parser.add_argument(
772
+ "--in-channels",
773
+ type=int,
774
+ metavar="N",
775
+ help="number of encoder input channels",
776
+ )
777
+ parser.add_argument(
778
+ "--transformer-context",
779
+ type=str,
780
+ metavar="EXPR",
781
+ help="""
782
+ either None or a tuple of two ints, indicating left/right context a
783
+ transformer can have access to""",
784
+ )
785
+ parser.add_argument(
786
+ "--transformer-sampling",
787
+ type=str,
788
+ metavar="EXPR",
789
+ help="""
790
+ either None or a tuple of ints, indicating sampling factor in each layer""",
791
+ )
792
+
793
+ @classmethod
794
+ def build_model(cls, args, task):
795
+ """Build a new model instance."""
796
+ base_architecture_enconly(args)
797
+ encoder = VGGTransformerEncoderOnly(
798
+ vocab_size=len(task.target_dictionary),
799
+ input_feat_per_channel=args.input_feat_per_channel,
800
+ vggblock_config=eval(args.vggblock_enc_config),
801
+ transformer_config=eval(args.transformer_enc_config),
802
+ encoder_output_dim=args.enc_output_dim,
803
+ in_channels=args.in_channels,
804
+ transformer_context=eval(args.transformer_context),
805
+ transformer_sampling=eval(args.transformer_sampling),
806
+ )
807
+ return cls(encoder)
808
+
809
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
810
+ # net_output['encoder_out'] is a (T, B, D) tensor
811
+ lprobs = super().get_normalized_probs(net_output, log_probs, sample)
812
+ # lprobs is a (T, B, D) tensor
813
+ # we need to transpose to get a (B, T, D) tensor
814
+ lprobs = lprobs.transpose(0, 1).contiguous()
815
+ lprobs.batch_first = True
816
+ return lprobs
817
+
818
+
819
+ class VGGTransformerEncoderOnly(VGGTransformerEncoder):
820
+ def __init__(
821
+ self,
822
+ vocab_size,
823
+ input_feat_per_channel,
824
+ vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
825
+ transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
826
+ encoder_output_dim=512,
827
+ in_channels=1,
828
+ transformer_context=None,
829
+ transformer_sampling=None,
830
+ ):
831
+ super().__init__(
832
+ input_feat_per_channel=input_feat_per_channel,
833
+ vggblock_config=vggblock_config,
834
+ transformer_config=transformer_config,
835
+ encoder_output_dim=encoder_output_dim,
836
+ in_channels=in_channels,
837
+ transformer_context=transformer_context,
838
+ transformer_sampling=transformer_sampling,
839
+ )
840
+ self.fc_out = Linear(self.encoder_output_dim, vocab_size)
841
+
842
+ def forward(self, src_tokens, src_lengths, **kwargs):
843
+ """
844
+ src_tokens: padded tensor (B, T, C * feat)
845
+ src_lengths: tensor of original lengths of input utterances (B,)
846
+ """
847
+
848
+ enc_out = super().forward(src_tokens, src_lengths)
849
+ x = self.fc_out(enc_out["encoder_out"])
850
+ # x = F.log_softmax(x, dim=-1)
851
+ # Note: this line is not needed because model.get_normalized_probs will call
852
+ # log_softmax
853
+ return {
854
+ "encoder_out": x, # (T, B, C)
855
+ "encoder_padding_mask": enc_out["encoder_padding_mask"], # (T, B)
856
+ }
857
+
858
+ def max_positions(self):
859
+ """Maximum input length supported by the encoder."""
860
+ return (1e6, 1e6) # an arbitrary large number
861
+
862
+
863
+ def Embedding(num_embeddings, embedding_dim, padding_idx):
864
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
865
+ # nn.init.uniform_(m.weight, -0.1, 0.1)
866
+ # nn.init.constant_(m.weight[padding_idx], 0)
867
+ return m
868
+
869
+
870
+ def Linear(in_features, out_features, bias=True, dropout=0):
871
+ """Linear layer (input: N x T x C)"""
872
+ m = nn.Linear(in_features, out_features, bias=bias)
873
+ # m.weight.data.uniform_(-0.1, 0.1)
874
+ # if bias:
875
+ # m.bias.data.uniform_(-0.1, 0.1)
876
+ return m
877
+
878
+
879
+ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
880
+ """Weight-normalized Conv1d layer optimized for decoding"""
881
+ m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
882
+ std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
883
+ nn.init.normal_(m.weight, mean=0, std=std)
884
+ nn.init.constant_(m.bias, 0)
885
+ return nn.utils.weight_norm(m, dim=2)
886
+
887
+
888
+ def LayerNorm(embedding_dim):
889
+ m = nn.LayerNorm(embedding_dim)
890
+ return m
891
+
892
+
893
+ # seq2seq models
894
+ def base_architecture(args):
895
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
896
+ args.vggblock_enc_config = getattr(
897
+ args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG
898
+ )
899
+ args.transformer_enc_config = getattr(
900
+ args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG
901
+ )
902
+ args.enc_output_dim = getattr(args, "enc_output_dim", 512)
903
+ args.in_channels = getattr(args, "in_channels", 1)
904
+ args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
905
+ args.transformer_dec_config = getattr(
906
+ args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG
907
+ )
908
+ args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG)
909
+ args.transformer_context = getattr(args, "transformer_context", "None")
910
+
911
+
912
+ @register_model_architecture("asr_vggtransformer", "vggtransformer_1")
913
+ def vggtransformer_1(args):
914
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
915
+ args.vggblock_enc_config = getattr(
916
+ args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
917
+ )
918
+ args.transformer_enc_config = getattr(
919
+ args,
920
+ "transformer_enc_config",
921
+ "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14",
922
+ )
923
+ args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
924
+ args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
925
+ args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
926
+ args.transformer_dec_config = getattr(
927
+ args,
928
+ "transformer_dec_config",
929
+ "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4",
930
+ )
931
+
932
+
933
+ @register_model_architecture("asr_vggtransformer", "vggtransformer_2")
934
+ def vggtransformer_2(args):
935
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
936
+ args.vggblock_enc_config = getattr(
937
+ args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
938
+ )
939
+ args.transformer_enc_config = getattr(
940
+ args,
941
+ "transformer_enc_config",
942
+ "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
943
+ )
944
+ args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
945
+ args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
946
+ args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
947
+ args.transformer_dec_config = getattr(
948
+ args,
949
+ "transformer_dec_config",
950
+ "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6",
951
+ )
952
+
953
+
954
+ @register_model_architecture("asr_vggtransformer", "vggtransformer_base")
955
+ def vggtransformer_base(args):
956
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
957
+ args.vggblock_enc_config = getattr(
958
+ args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
959
+ )
960
+ args.transformer_enc_config = getattr(
961
+ args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
962
+ )
963
+
964
+ args.enc_output_dim = getattr(args, "enc_output_dim", 512)
965
+ args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
966
+ args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
967
+ args.transformer_dec_config = getattr(
968
+ args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
969
+ )
970
+ # Size estimations:
971
+ # Encoder:
972
+ # - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3*3 = 258K
973
+ # Transformer:
974
+ # - input dimension adapter: 2560 x 512 -> 1.31M
975
+ # - transformer_layers (x12) --> 37.74M
976
+ # * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
977
+ # * FFN weight: 512*2048*2 = 2.097M
978
+ # - output dimension adapter: 512 x 512 -> 0.26 M
979
+ # Decoder:
980
+ # - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
981
+ # - transformer_layer: (x6) --> 25.16M
982
+ # * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
983
+ # * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
984
+ # * FFN: 512*2048*2 = 2.097M
985
+ # Final FC:
986
+ # - FC: 512*5000 = 2.56M (assuming vocab size 5K)
+ # In total:
+ # ~68 M
989
+
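+ # Back-of-envelope check of the per-layer numbers above (illustrative only):
+ # 512 * 512 * 3 + 512 * 512 = 1_048_576 # MultiheadAttention in/out projections
+ # 512 * 2048 * 2 = 2_097_152 # FFN weights
+ # 12 * (1_048_576 + 2_097_152) = 37_748_736 # 12 encoder layers (weights only)
+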
990
+
991
+ # CTC models
992
+ def base_architecture_enconly(args):
993
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
994
+ args.vggblock_enc_config = getattr(
995
+ args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2"
996
+ )
997
+ args.transformer_enc_config = getattr(
998
+ args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
999
+ )
1000
+ args.enc_output_dim = getattr(args, "enc_output_dim", 512)
1001
+ args.in_channels = getattr(args, "in_channels", 1)
1002
+ args.transformer_context = getattr(args, "transformer_context", "None")
1003
+ args.transformer_sampling = getattr(args, "transformer_sampling", "None")
1004
+
1005
+
1006
+ @register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1")
1007
+ def vggtransformer_enc_1(args):
1008
+ # vggtransformer_1 is the same as vggtransformer_enc_big, except the number
1009
+ # of layers is increased to 16
1010
+ # keep it here for backward compatibility purposes
1011
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
1012
+ args.vggblock_enc_config = getattr(
1013
+ args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
1014
+ )
1015
+ args.transformer_enc_config = getattr(
1016
+ args,
1017
+ "transformer_enc_config",
1018
+ "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
1019
+ )
1020
+ args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py ADDED
@@ -0,0 +1,177 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from fairseq.models import (
14
+ FairseqEncoder,
15
+ FairseqEncoderModel,
16
+ register_model,
17
+ register_model_architecture,
18
+ )
19
+ from fairseq.modules.fairseq_dropout import FairseqDropout
20
+
21
+
22
+ default_conv_enc_config = """[
23
+ (400, 13, 170, 0.2),
24
+ (440, 14, 0, 0.214),
25
+ (484, 15, 0, 0.22898),
26
+ (532, 16, 0, 0.2450086),
27
+ (584, 17, 0, 0.262159202),
28
+ (642, 18, 0, 0.28051034614),
29
+ (706, 19, 0, 0.30014607037),
30
+ (776, 20, 0, 0.321156295296),
31
+ (852, 21, 0, 0.343637235966),
32
+ (936, 22, 0, 0.367691842484),
33
+ (1028, 23, 0, 0.393430271458),
34
+ (1130, 24, 0, 0.42097039046),
35
+ (1242, 25, 0, 0.450438317792),
36
+ (1366, 26, 0, 0.481969000038),
37
+ (1502, 27, 0, 0.51570683004),
38
+ (1652, 28, 0, 0.551806308143),
39
+ (1816, 29, 0, 0.590432749713),
40
+ ]"""
41
+
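+ # Each tuple above is (out_channels, kernel_size, padding, dropout), matching
+ # the --conv-enc-config format documented in add_args below.
+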
42
+
43
+ @register_model("asr_w2l_conv_glu_encoder")
44
+ class W2lConvGluEncoderModel(FairseqEncoderModel):
45
+ def __init__(self, encoder):
46
+ super().__init__(encoder)
47
+
48
+ @staticmethod
49
+ def add_args(parser):
50
+ """Add model-specific arguments to the parser."""
51
+ parser.add_argument(
52
+ "--input-feat-per-channel",
53
+ type=int,
54
+ metavar="N",
55
+ help="encoder input dimension per input channel",
56
+ )
57
+ parser.add_argument(
58
+ "--in-channels",
59
+ type=int,
60
+ metavar="N",
61
+ help="number of encoder input channels",
62
+ )
63
+ parser.add_argument(
64
+ "--conv-enc-config",
65
+ type=str,
66
+ metavar="EXPR",
67
+ help="""
68
+ an array of tuples each containing the configuration of one conv layer
69
+ [(out_channels, kernel_size, padding, dropout), ...]
70
+ """,
71
+ )
72
+
73
+ @classmethod
74
+ def build_model(cls, args, task):
75
+ """Build a new model instance."""
76
+ conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
77
+ encoder = W2lConvGluEncoder(
78
+ vocab_size=len(task.target_dictionary),
79
+ input_feat_per_channel=args.input_feat_per_channel,
80
+ in_channels=args.in_channels,
81
+ conv_enc_config=eval(conv_enc_config),
82
+ )
83
+ return cls(encoder)
84
+
85
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
86
+ lprobs = super().get_normalized_probs(net_output, log_probs, sample)
87
+ lprobs.batch_first = False
88
+ return lprobs
89
+
90
+
91
+ class W2lConvGluEncoder(FairseqEncoder):
92
+ def __init__(
93
+ self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
94
+ ):
95
+ super().__init__(None)
96
+
97
+ self.input_dim = input_feat_per_channel
98
+ if in_channels != 1:
99
+ raise ValueError("only 1 input channel is currently supported")
100
+
101
+ self.conv_layers = nn.ModuleList()
102
+ self.linear_layers = nn.ModuleList()
103
+ self.dropouts = []
104
+ cur_channels = input_feat_per_channel
105
+
106
+ for out_channels, kernel_size, padding, dropout in conv_enc_config:
107
+ layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
108
+ layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init
109
+ self.conv_layers.append(nn.utils.weight_norm(layer))
110
+ self.dropouts.append(
111
+ FairseqDropout(dropout, module_name=self.__class__.__name__)
112
+ )
113
+ if out_channels % 2 != 0:
114
+ raise ValueError("odd # of out_channels is incompatible with GLU")
115
+ cur_channels = out_channels // 2 # halved by GLU
116
+
117
+ for out_channels in [2 * cur_channels, vocab_size]:
118
+ layer = nn.Linear(cur_channels, out_channels)
119
+ layer.weight.data.mul_(math.sqrt(3))
120
+ self.linear_layers.append(nn.utils.weight_norm(layer))
121
+ cur_channels = out_channels // 2
122
+
123
+ def forward(self, src_tokens, src_lengths, **kwargs):
124
+
125
+ """
126
+ src_tokens: padded tensor (B, T, C * feat)
127
+ src_lengths: tensor of original lengths of input utterances (B,)
128
+ """
129
+ B, T, _ = src_tokens.size()
130
+ x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1
131
+
132
+ for layer_idx in range(len(self.conv_layers)):
133
+ x = self.conv_layers[layer_idx](x)
134
+ x = F.glu(x, dim=1)
135
+ x = self.dropouts[layer_idx](x)
136
+
137
+ x = x.transpose(1, 2).contiguous() # (B, T, C); C == 908 (1816 // 2 after GLU) with the default config
138
+ x = self.linear_layers[0](x)
139
+ x = F.glu(x, dim=2)
140
+ x = self.dropouts[-1](x)
141
+ x = self.linear_layers[1](x)
142
+
143
+ assert x.size(0) == B
144
+ assert x.size(1) == T
145
+
146
+ encoder_out = x.transpose(0, 1) # (T, B, vocab_size)
147
+
148
+ # need to debug this -- find a simpler/elegant way in pytorch APIs
149
+ encoder_padding_mask = (
150
+ torch.arange(T).view(1, T).expand(B, -1).to(x.device)
151
+ >= src_lengths.view(B, 1).expand(-1, T)
152
+ ).t() # (B x T) -> (T x B)
153
+
154
+ return {
155
+ "encoder_out": encoder_out, # (T, B, vocab_size)
156
+ "encoder_padding_mask": encoder_padding_mask, # (T, B)
157
+ }
158
+
159
+ def reorder_encoder_out(self, encoder_out, new_order):
160
+ encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
161
+ 1, new_order
162
+ )
163
+ encoder_out["encoder_padding_mask"] = encoder_out[
164
+ "encoder_padding_mask"
165
+ ].index_select(1, new_order)
166
+ return encoder_out
167
+
168
+ def max_positions(self):
169
+ """Maximum input length supported by the encoder."""
170
+ return (1e6, 1e6) # an arbitrary large number
171
+
172
+
173
+ @register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
174
+ def w2l_conv_glu_enc(args):
175
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
176
+ args.in_channels = getattr(args, "in_channels", 1)
177
+ args.conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
fairseq/examples/speech_recognition/new/README.md ADDED
@@ -0,0 +1,43 @@
1
+ # Flashlight Decoder
2
+
3
+ This script runs decoding for pre-trained speech recognition models.
4
+
5
+ ## Usage
6
+
7
+ Assuming a few variables:
8
+
9
+ ```bash
10
+ checkpoint=<path-to-checkpoint>
11
+ data=<path-to-data-directory>
12
+ lm_model=<path-to-language-model>
13
+ lexicon=<path-to-lexicon>
14
+ ```
15
+
16
+ Example usage for decoding a fine-tuned Wav2Vec model:
17
+
18
+ ```bash
19
+ python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
20
+ task=audio_pretraining \
21
+ task.data=$data \
22
+ task.labels=ltr \
23
+ common_eval.path=$checkpoint \
24
+ decoding.type=kenlm \
25
+ decoding.lexicon=$lexicon \
26
+ decoding.lmpath=$lm_model \
27
+ dataset.gen_subset=dev_clean,dev_other,test_clean,test_other
28
+ ```
29
+
30
+ Example usage for using Ax to sweep WER parameters (requires `pip install hydra-ax-sweeper`):
31
+
32
+ ```bash
33
+ python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
34
+ hydra/sweeper=ax \
35
+ task=audio_pretraining \
36
+ task.data=$data \
37
+ task.labels=ltr \
38
+ common_eval.path=$checkpoint \
39
+ decoding.type=kenlm \
40
+ decoding.lexicon=$lexicon \
41
+ decoding.lmpath=$lm_model \
42
+ dataset.gen_subset=dev_other
43
+ ```
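+
+ For a quick sanity check without a language model, a viterbi-only run can reuse
+ the same variables (this is a sketch; it relies on the defaults in `conf/infer.yaml`):
+
+ ```bash
+ python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py \
+ task=audio_pretraining \
+ task.data=$data \
+ task.labels=ltr \
+ common_eval.path=$checkpoint \
+ decoding.type=viterbi \
+ dataset.gen_subset=dev_other
+ ```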
fairseq/examples/speech_recognition/new/__init__.py ADDED
File without changes
fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package hydra.sweeper
2
+ _target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
3
+ max_batch_size: null
4
+ ax_config:
5
+ max_trials: 128
6
+ early_stop:
7
+ minimize: true
8
+ max_epochs_without_improvement: 10
9
+ epsilon: 0.025
10
+ experiment:
11
+ name: ${dataset.gen_subset}
12
+ objective_name: wer
13
+ minimize: true
14
+ parameter_constraints: null
15
+ outcome_constraints: null
16
+ status_quo: null
17
+ client:
18
+ verbose_logging: false
19
+ random_seed: null
20
+ params:
21
+ decoding.lmweight:
22
+ type: range
23
+ bounds: [0.0, 5.0]
24
+ decoding.wordscore:
25
+ type: range
26
+ bounds: [-5.0, 5.0]
27
+ decoding.silweight:
28
+ type: range
29
+ bounds: [ -8.0, 0.0 ]
fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package hydra.sweeper
2
+ _target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
3
+ max_batch_size: null
4
+ ax_config:
5
+ max_trials: 64
6
+ early_stop:
7
+ minimize: true
8
+ max_epochs_without_improvement: 10
9
+ epsilon: 0.025
10
+ experiment:
11
+ name: ${dataset.gen_subset}
12
+ objective_name: wer
13
+ minimize: true
14
+ parameter_constraints: null
15
+ outcome_constraints: null
16
+ status_quo: null
17
+ client:
18
+ verbose_logging: false
19
+ random_seed: null
20
+ params:
21
+ decoding.lmweight:
22
+ type: range
23
+ bounds: [0.0, 10.0]
24
+ decoding.wordscore:
25
+ type: range
26
+ bounds: [-10.0, 10.0]
27
+ decoding.silweight:
28
+ type: range
29
+ bounds: [ -10.0, 0.0 ]
fairseq/examples/speech_recognition/new/conf/infer.yaml ADDED
@@ -0,0 +1,27 @@
1
+ # @package _group_
2
+
3
+ defaults:
4
+ - task: null
5
+ - model: null
6
+
7
+ hydra:
8
+ run:
9
+ dir: ${common_eval.results_path}/${dataset.gen_subset}
10
+ sweep:
11
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
12
+ subdir: ${dataset.gen_subset}
13
+ common:
14
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
15
+ common_eval:
16
+ results_path: null
17
+ path: null
18
+ post_process: letter
19
+ quiet: true
20
+ dataset:
21
+ max_tokens: 3000000
22
+ gen_subset: test
23
+ distributed_training:
24
+ distributed_world_size: 1
25
+ decoding:
26
+ beam: 5
27
+ type: viterbi
fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml ADDED
@@ -0,0 +1,28 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '/'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ - common_eval.path
13
+ sweep:
14
+ dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
15
+ # subdir: ${hydra.job.override_dirname}
16
+ launcher:
17
+ cpus_per_task: 16
18
+ gpus_per_node: 1
19
+ tasks_per_node: 1
20
+ nodes: 1
21
+ partition: devlab,learnlab
22
+ mem_gb: 100
23
+ timeout_min: 2000
24
+ max_num_timeout: 10
25
+ name: ${env:PREFIX}_${hydra.job.config_name}
26
+ submitit_folder: ${hydra.sweep.dir}/%j
27
+ constraint: volta32gb
28
+ exclude: learnfair7598
fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml ADDED
@@ -0,0 +1,27 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '/'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ - common_eval.path
13
+ sweep:
14
+ dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
15
+ # subdir: ${hydra.job.override_dirname}
16
+ launcher:
17
+ cpus_per_task: 16
18
+ gpus_per_node: 2
19
+ tasks_per_node: 2
20
+ nodes: 1
21
+ partition: devlab,learnlab
22
+ mem_gb: 100
23
+ timeout_min: 2000
24
+ max_num_timeout: 10
25
+ name: ${env:PREFIX}_${hydra.job.config_name}
26
+ submitit_folder: ${hydra.sweep.dir}/%j
27
+ constraint: volta32gb
fairseq/examples/speech_recognition/new/decoders/__init__.py ADDED
File without changes
fairseq/examples/speech_recognition/new/decoders/base_decoder.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools as it
7
+ from typing import Any, Dict, List
8
+
9
+ import torch
10
+ from fairseq.data.dictionary import Dictionary
11
+ from fairseq.models.fairseq_model import FairseqModel
12
+
13
+
14
+ class BaseDecoder:
15
+ def __init__(self, tgt_dict: Dictionary) -> None:
16
+ self.tgt_dict = tgt_dict
17
+ self.vocab_size = len(tgt_dict)
18
+
19
+ self.blank = (
20
+ tgt_dict.index("<ctc_blank>")
21
+ if "<ctc_blank>" in tgt_dict.indices
22
+ else tgt_dict.bos()
23
+ )
24
+ if "<sep>" in tgt_dict.indices:
25
+ self.silence = tgt_dict.index("<sep>")
26
+ elif "|" in tgt_dict.indices:
27
+ self.silence = tgt_dict.index("|")
28
+ else:
29
+ self.silence = tgt_dict.eos()
30
+
31
+ def generate(
32
+ self, models: List[FairseqModel], sample: Dict[str, Any], **unused
33
+ ) -> List[List[Dict[str, torch.LongTensor]]]:
34
+ encoder_input = {
35
+ k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
36
+ }
37
+ emissions = self.get_emissions(models, encoder_input)
38
+ return self.decode(emissions)
39
+
40
+ def get_emissions(
41
+ self,
42
+ models: List[FairseqModel],
43
+ encoder_input: Dict[str, Any],
44
+ ) -> torch.FloatTensor:
45
+ model = models[0]
46
+ encoder_out = model(**encoder_input)
47
+ if hasattr(model, "get_logits"):
48
+ emissions = model.get_logits(encoder_out)
49
+ else:
50
+ emissions = model.get_normalized_probs(encoder_out, log_probs=True)
51
+ return emissions.transpose(0, 1).float().cpu().contiguous()
52
+
53
+ def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
54
+ idxs = (g[0] for g in it.groupby(idxs))
55
+ idxs = filter(lambda x: x != self.blank, idxs)
56
+ return torch.LongTensor(list(idxs))
57
+
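+ # Illustrative example (not in the original source): with self.blank == 0,
+ # get_tokens([0, 5, 5, 0, 7, 7, 0]) collapses repeats, drops blanks and
+ # returns tensor([5, 7]).
+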
58
+ def decode(
59
+ self,
60
+ emissions: torch.FloatTensor,
61
+ ) -> List[List[Dict[str, torch.LongTensor]]]:
62
+ raise NotImplementedError
fairseq/examples/speech_recognition/new/decoders/decoder.py ADDED
@@ -0,0 +1,32 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from typing import Union
9
+
10
+ from fairseq.data.dictionary import Dictionary
11
+
12
+ from .decoder_config import DecoderConfig, FlashlightDecoderConfig
13
+ from .base_decoder import BaseDecoder
14
+
15
+
16
+ def Decoder(
17
+ cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary
18
+ ) -> BaseDecoder:
19
+
20
+ if cfg.type == "viterbi":
21
+ from .viterbi_decoder import ViterbiDecoder
22
+
23
+ return ViterbiDecoder(tgt_dict)
24
+ if cfg.type == "kenlm":
25
+ from .flashlight_decoder import KenLMDecoder
26
+
27
+ return KenLMDecoder(cfg, tgt_dict)
28
+ if cfg.type == "fairseqlm":
29
+ from .flashlight_decoder import FairseqLMDecoder
30
+
31
+ return FairseqLMDecoder(cfg, tgt_dict)
32
+ raise NotImplementedError(f"Invalid decoder type: {cfg.type}")
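+
+ # Usage sketch (illustrative only; `tgt_dict` is assumed to be a loaded fairseq
+ # Dictionary, e.g. a task's target_dictionary):
+ # cfg = DecoderConfig() # defaults to type="viterbi"
+ # decoder = Decoder(cfg, tgt_dict)
+ # hypos = decoder.generate(models, sample)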
fairseq/examples/speech_recognition/new/decoders/decoder_config.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional
9
+
10
+ from fairseq.dataclass.configs import FairseqDataclass
11
+ from fairseq.dataclass.constants import ChoiceEnum
12
+ from omegaconf import MISSING
13
+
14
+
15
+ DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"])
16
+
17
+
18
+ @dataclass
19
+ class DecoderConfig(FairseqDataclass):
20
+ type: DECODER_CHOICES = field(
21
+ default="viterbi",
22
+ metadata={"help": "The type of decoder to use"},
23
+ )
24
+
25
+
26
+ @dataclass
27
+ class FlashlightDecoderConfig(FairseqDataclass):
28
+ nbest: int = field(
29
+ default=1,
30
+ metadata={"help": "Number of decodings to return"},
31
+ )
32
+ unitlm: bool = field(
33
+ default=False,
34
+ metadata={"help": "If set, use unit language model"},
35
+ )
36
+ lmpath: str = field(
37
+ default=MISSING,
38
+ metadata={"help": "Language model for KenLM decoder"},
39
+ )
40
+ lexicon: Optional[str] = field(
41
+ default=None,
42
+ metadata={"help": "Lexicon for Flashlight decoder"},
43
+ )
44
+ beam: int = field(
45
+ default=50,
46
+ metadata={"help": "Number of beams to use for decoding"},
47
+ )
48
+ beamthreshold: float = field(
49
+ default=50.0,
50
+ metadata={"help": "Threshold for beam search decoding"},
51
+ )
52
+ beamsizetoken: Optional[int] = field(
53
+ default=None, metadata={"help": "Beam size to use"}
54
+ )
55
+ wordscore: float = field(
56
+ default=-1,
57
+ metadata={"help": "Word score for KenLM decoder"},
58
+ )
59
+ unkweight: float = field(
60
+ default=-math.inf,
61
+ metadata={"help": "Unknown weight for KenLM decoder"},
62
+ )
63
+ silweight: float = field(
64
+ default=0,
65
+ metadata={"help": "Silence weight for KenLM decoder"},
66
+ )
67
+ lmweight: float = field(
68
+ default=2,
69
+ metadata={"help": "Weight for LM while interpolating score"},
70
+ )
fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py ADDED
@@ -0,0 +1,433 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import gc
9
+ import os.path as osp
10
+ import warnings
11
+ from collections import deque, namedtuple
12
+ from typing import Any, Dict, Tuple
13
+
14
+ import numpy as np
15
+ import torch
16
+ from fairseq import tasks
17
+ from fairseq.data.dictionary import Dictionary
18
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
19
+ from fairseq.models.fairseq_model import FairseqModel
20
+ from fairseq.utils import apply_to_sample
21
+ from omegaconf import open_dict, OmegaConf
22
+
23
+ from typing import List
24
+
25
+ from .decoder_config import FlashlightDecoderConfig
26
+ from .base_decoder import BaseDecoder
27
+
28
+ try:
29
+ from flashlight.lib.text.decoder import (
30
+ LM,
31
+ CriterionType,
32
+ DecodeResult,
33
+ KenLM,
34
+ LexiconDecoder,
35
+ LexiconDecoderOptions,
36
+ LexiconFreeDecoder,
37
+ LexiconFreeDecoderOptions,
38
+ LMState,
39
+ SmearingMode,
40
+ Trie,
41
+ )
42
+ from flashlight.lib.text.dictionary import create_word_dict, load_words
43
+ from flashlight.lib.text.dictionary import Dictionary as flDictionary
44
+ except ImportError:
45
+ warnings.warn(
46
+ "flashlight python bindings are required to use this functionality. "
47
+ "Please install from "
48
+ "https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
49
+ )
50
+ LM = object
51
+ LMState = object
52
+
53
+
54
+ class KenLMDecoder(BaseDecoder):
55
+ def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
56
+ super().__init__(tgt_dict)
57
+
58
+ self.nbest = cfg.nbest
59
+ self.unitlm = cfg.unitlm
60
+
61
+ if cfg.lexicon:
62
+ self.lexicon = load_words(cfg.lexicon)
63
+ self.word_dict = create_word_dict(self.lexicon)
64
+ self.unk_word = self.word_dict.get_index("<unk>")
65
+
66
+ self.lm = KenLM(cfg.lmpath, self.word_dict)
67
+ self.trie = Trie(self.vocab_size, self.silence)
68
+
69
+ start_state = self.lm.start(False)
70
+ for word, spellings in self.lexicon.items():
71
+ word_idx = self.word_dict.get_index(word)
72
+ _, score = self.lm.score(start_state, word_idx)
73
+ for spelling in spellings:
74
+ spelling_idxs = [tgt_dict.index(token) for token in spelling]
75
+ assert (
76
+ tgt_dict.unk() not in spelling_idxs
77
+ ), f"{word} {spelling} {spelling_idxs}"
78
+ self.trie.insert(spelling_idxs, word_idx, score)
79
+ self.trie.smear(SmearingMode.MAX)
80
+
81
+ self.decoder_opts = LexiconDecoderOptions(
82
+ beam_size=cfg.beam,
83
+ beam_size_token=cfg.beamsizetoken or len(tgt_dict),
84
+ beam_threshold=cfg.beamthreshold,
85
+ lm_weight=cfg.lmweight,
86
+ word_score=cfg.wordscore,
87
+ unk_score=cfg.unkweight,
88
+ sil_score=cfg.silweight,
89
+ log_add=False,
90
+ criterion_type=CriterionType.CTC,
91
+ )
92
+
93
+ self.decoder = LexiconDecoder(
94
+ self.decoder_opts,
95
+ self.trie,
96
+ self.lm,
97
+ self.silence,
98
+ self.blank,
99
+ self.unk_word,
100
+ [],
101
+ self.unitlm,
102
+ )
103
+ else:
104
+ assert self.unitlm, "Lexicon-free decoding requires unit LM"
105
+
106
+ self.word_dict = flDictionary()
107
+ for sym in tgt_dict.symbols:
108
+ self.word_dict.add_entry(sym, tgt_dict.index(sym))
109
+ self.lm = KenLM(cfg.lmpath, self.word_dict)
110
+ self.decoder_opts = LexiconFreeDecoderOptions(
111
+ beam_size=cfg.beam,
112
+ beam_size_token=cfg.beamsizetoken or len(tgt_dict),
113
+ beam_threshold=cfg.beamthreshold,
114
+ lm_weight=cfg.lmweight,
115
+ sil_score=cfg.silweight,
116
+ log_add=False,
117
+ criterion_type=CriterionType.CTC,
118
+ )
119
+ self.decoder = LexiconFreeDecoder(
120
+ self.decoder_opts, self.lm, self.silence, self.blank, []
121
+ )
122
+
123
+ def get_timesteps(self, token_idxs: List[int]) -> List[int]:
124
+ """Returns frame numbers corresponding to every non-blank token.
125
+
126
+ Parameters
127
+ ----------
128
+ token_idxs : List[int]
129
+ IDs of decoded tokens.
130
+
131
+ Returns
132
+ -------
133
+ List[int]
134
+ Frame numbers corresponding to every non-blank token.
135
+ """
136
+ timesteps = []
137
+ for i, token_idx in enumerate(token_idxs):
138
+ if token_idx == self.blank:
139
+ continue
140
+ if i == 0 or token_idx != token_idxs[i-1]:
141
+ timesteps.append(i)
142
+ return timesteps
143
+
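+ # Example (illustrative): with self.blank == 0,
+ # get_timesteps([0, 5, 5, 0, 7]) -> [1, 4]
+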
144
+ def decode(
145
+ self,
146
+ emissions: torch.FloatTensor,
147
+ ) -> List[List[Dict[str, torch.LongTensor]]]:
148
+ B, T, N = emissions.size()
149
+ hypos = []
150
+ for b in range(B):
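+ # NOTE: the 4-byte offset assumes float32 emissions; BaseDecoder.get_emissions()
+ # calls .float() on the model output, so that holds here.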
151
+ emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
152
+ results = self.decoder.decode(emissions_ptr, T, N)
153
+
154
+ nbest_results = results[: self.nbest]
155
+ hypos.append(
156
+ [
157
+ {
158
+ "tokens": self.get_tokens(result.tokens),
159
+ "score": result.score,
160
+ "timesteps": self.get_timesteps(result.tokens),
161
+ "words": [
162
+ self.word_dict.get_entry(x) for x in result.words if x >= 0
163
+ ],
164
+ }
165
+ for result in nbest_results
166
+ ]
167
+ )
168
+ return hypos
169
+
170
+
171
+ FairseqLMState = namedtuple(
172
+ "FairseqLMState",
173
+ [
174
+ "prefix",
175
+ "incremental_state",
176
+ "probs",
177
+ ],
178
+ )
179
+
180
+
181
+ class FairseqLM(LM):
182
+ def __init__(self, dictionary: Dictionary, model: FairseqModel) -> None:
183
+ super().__init__()
184
+
185
+ self.dictionary = dictionary
186
+ self.model = model
187
+ self.unk = self.dictionary.unk()
188
+
189
+ self.save_incremental = False # this currently does not work properly
190
+ self.max_cache = 20_000
191
+
192
+ if torch.cuda.is_available():
193
+ model.cuda()
194
+ model.eval()
195
+ model.make_generation_fast_()
196
+
197
+ self.states = {}
198
+ self.stateq = deque()
199
+
200
+ def start(self, start_with_nothing: bool) -> LMState:
201
+ state = LMState()
202
+ prefix = torch.LongTensor([[self.dictionary.eos()]])
203
+ incremental_state = {} if self.save_incremental else None
204
+ with torch.no_grad():
205
+ res = self.model(prefix.cuda(), incremental_state=incremental_state)
206
+ probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
207
+
208
+ if incremental_state is not None:
209
+ incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
210
+ self.states[state] = FairseqLMState(
211
+ prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
212
+ )
213
+ self.stateq.append(state)
214
+
215
+ return state
216
+
217
+ def score(
218
+ self,
219
+ state: LMState,
220
+ token_index: int,
221
+ no_cache: bool = False,
222
+ ) -> Tuple[LMState, int]:
223
+ """
224
+ Evaluate language model based on the current lm state and new word
225
+ Parameters:
226
+ -----------
227
+ state: current lm state
228
+ token_index: index of the word
229
+ (can be lexicon index then you should store inside LM the
230
+ mapping between indices of lexicon and lm, or lm index of a word)
231
+ Returns:
232
+ --------
233
+ (LMState, float): pair of (new state, score for the current word)
234
+ """
235
+ curr_state = self.states[state]
236
+
237
+ def trim_cache(targ_size: int) -> None:
238
+ while len(self.stateq) > targ_size:
239
+ rem_k = self.stateq.popleft()
240
+ rem_st = self.states[rem_k]
241
+ rem_st = FairseqLMState(rem_st.prefix, None, None)
242
+ self.states[rem_k] = rem_st
243
+
244
+ if curr_state.probs is None:
245
+ new_incremental_state = (
246
+ curr_state.incremental_state.copy()
247
+ if curr_state.incremental_state is not None
248
+ else None
249
+ )
250
+ with torch.no_grad():
251
+ if new_incremental_state is not None:
252
+ new_incremental_state = apply_to_sample(
253
+ lambda x: x.cuda(), new_incremental_state
254
+ )
255
+ elif self.save_incremental:
256
+ new_incremental_state = {}
257
+
258
+ res = self.model(
259
+ torch.from_numpy(curr_state.prefix).cuda(),
260
+ incremental_state=new_incremental_state,
261
+ )
262
+ probs = self.model.get_normalized_probs(
263
+ res, log_probs=True, sample=None
264
+ )
265
+
266
+ if new_incremental_state is not None:
267
+ new_incremental_state = apply_to_sample(
268
+ lambda x: x.cpu(), new_incremental_state
269
+ )
270
+
271
+ curr_state = FairseqLMState(
272
+ curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
273
+ )
274
+
275
+ if not no_cache:
276
+ self.states[state] = curr_state
277
+ self.stateq.append(state)
278
+
279
+ score = curr_state.probs[token_index].item()
280
+
281
+ trim_cache(self.max_cache)
282
+
283
+ outstate = state.child(token_index)
284
+ if outstate not in self.states and not no_cache:
285
+ prefix = np.concatenate(
286
+ [curr_state.prefix, torch.LongTensor([[token_index]])], -1
287
+ )
288
+ incr_state = curr_state.incremental_state
289
+
290
+ self.states[outstate] = FairseqLMState(prefix, incr_state, None)
291
+
292
+ if token_index == self.unk:
293
+ score = float("-inf")
294
+
295
+ return outstate, score
296
+
297
+ def finish(self, state: LMState) -> Tuple[LMState, int]:
298
+ """
299
+ Evaluate eos for language model based on the current lm state
300
+ Returns:
301
+ --------
302
+ (LMState, float): pair of (new state, score for the current word)
303
+ """
304
+ return self.score(state, self.dictionary.eos())
305
+
306
+ def empty_cache(self) -> None:
307
+ self.states = {}
308
+ self.stateq = deque()
309
+ gc.collect()
310
+
311
+
312
+ class FairseqLMDecoder(BaseDecoder):
313
+ def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
314
+ super().__init__(tgt_dict)
315
+
316
+ self.nbest = cfg.nbest
317
+ self.unitlm = cfg.unitlm
318
+
319
+ self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
320
+ self.idx_to_wrd = {}
321
+
322
+ checkpoint = torch.load(cfg.lmpath, map_location="cpu")
323
+
324
+ if "cfg" in checkpoint and checkpoint["cfg"] is not None:
325
+ lm_args = checkpoint["cfg"]
326
+ else:
327
+ lm_args = convert_namespace_to_omegaconf(checkpoint["args"])
328
+
329
+ if not OmegaConf.is_dict(lm_args):
330
+ lm_args = OmegaConf.create(lm_args)
331
+
332
+ with open_dict(lm_args.task):
333
+ lm_args.task.data = osp.dirname(cfg.lmpath)
334
+
335
+ task = tasks.setup_task(lm_args.task)
336
+ model = task.build_model(lm_args.model)
337
+ model.load_state_dict(checkpoint["model"], strict=False)
338
+
339
+ self.trie = Trie(self.vocab_size, self.silence)
340
+
341
+ self.word_dict = task.dictionary
342
+ self.unk_word = self.word_dict.unk()
343
+ self.lm = FairseqLM(self.word_dict, model)
344
+
345
+ if self.lexicon:
346
+ start_state = self.lm.start(False)
347
+ for i, (word, spellings) in enumerate(self.lexicon.items()):
348
+ if self.unitlm:
349
+ word_idx = i
350
+ self.idx_to_wrd[i] = word
351
+ score = 0
352
+ else:
353
+ word_idx = self.word_dict.index(word)
354
+ _, score = self.lm.score(start_state, word_idx, no_cache=True)
355
+
356
+ for spelling in spellings:
357
+ spelling_idxs = [tgt_dict.index(token) for token in spelling]
358
+ assert (
359
+ tgt_dict.unk() not in spelling_idxs
360
+ ), f"{spelling} {spelling_idxs}"
361
+ self.trie.insert(spelling_idxs, word_idx, score)
362
+ self.trie.smear(SmearingMode.MAX)
363
+
364
+ self.decoder_opts = LexiconDecoderOptions(
365
+ beam_size=cfg.beam,
366
+ beam_size_token=cfg.beamsizetoken or len(tgt_dict),
367
+ beam_threshold=cfg.beamthreshold,
368
+ lm_weight=cfg.lmweight,
369
+ word_score=cfg.wordscore,
370
+ unk_score=cfg.unkweight,
371
+ sil_score=cfg.silweight,
372
+ log_add=False,
373
+ criterion_type=CriterionType.CTC,
374
+ )
375
+
376
+ self.decoder = LexiconDecoder(
377
+ self.decoder_opts,
378
+ self.trie,
379
+ self.lm,
380
+ self.silence,
381
+ self.blank,
382
+ self.unk_word,
383
+ [],
384
+ self.unitlm,
385
+ )
386
+ else:
387
+ assert self.unitlm, "Lexicon-free decoding requires unit LM"
388
+
389
+ d = {w: [[w]] for w in tgt_dict.symbols}
390
+ self.word_dict = create_word_dict(d)
391
+ self.lm = KenLM(cfg.lmpath, self.word_dict)
392
+ self.decoder_opts = LexiconFreeDecoderOptions(
393
+ beam_size=cfg.beam,
394
+ beam_size_token=cfg.beamsizetoken or len(tgt_dict),
395
+ beam_threshold=cfg.beamthreshold,
396
+ lm_weight=cfg.lmweight,
397
+ sil_score=cfg.silweight,
398
+ log_add=False,
399
+ criterion_type=CriterionType.CTC,
400
+ )
401
+ self.decoder = LexiconFreeDecoder(
402
+ self.decoder_opts, self.lm, self.silence, self.blank, []
403
+ )
404
+
405
+ def decode(
406
+ self,
407
+ emissions: torch.FloatTensor,
408
+ ) -> List[List[Dict[str, torch.LongTensor]]]:
409
+ B, T, N = emissions.size()
410
+ hypos = []
411
+
412
+ def make_hypo(result: DecodeResult) -> Dict[str, Any]:
413
+ hypo = {
414
+ "tokens": self.get_tokens(result.tokens),
415
+ "score": result.score,
416
+ }
417
+ if self.lexicon:
418
+ hypo["words"] = [
419
+ self.idx_to_wrd[x] if self.unitlm else self.word_dict[x]
420
+ for x in result.words
421
+ if x >= 0
422
+ ]
423
+ return hypo
424
+
425
+ for b in range(B):
426
+ emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
427
+ results = self.decoder.decode(emissions_ptr, T, N)
428
+
429
+ nbest_results = results[: self.nbest]
430
+ hypos.append([make_hypo(result) for result in nbest_results])
431
+ self.lm.empty_cache()
432
+
433
+ return hypos
fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py ADDED
@@ -0,0 +1,24 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import torch
9
+
10
+ from typing import List, Dict
11
+
12
+ from .base_decoder import BaseDecoder
13
+
14
+
15
+ class ViterbiDecoder(BaseDecoder):
16
+ def decode(
17
+ self,
18
+ emissions: torch.FloatTensor,
19
+ ) -> List[List[Dict[str, torch.LongTensor]]]:
20
+ def get_pred(e):
21
+ score = e.log_softmax(dim=-1).max(dim=-1)[0].sum()
22
+ toks = e.argmax(dim=-1).unique_consecutive()
23
+ return {"tokens":toks[toks != self.blank], "score":score}
24
+ return [[get_pred(x)] for x in emissions]
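+
+ # Minimal usage sketch (illustrative; `tgt_dict` and `emissions` are assumed to
+ # come from the surrounding inference pipeline):
+ # decoder = ViterbiDecoder(tgt_dict)
+ # hypos = decoder.decode(emissions) # emissions: (B, T, V) float tensor
+ # best_tokens = hypos[0][0]["tokens"] # collapsed, blank-free LongTensor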
fairseq/examples/speech_recognition/new/infer.py ADDED
@@ -0,0 +1,502 @@
1
+ #!/usr/bin/env python -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import ast
8
+ import hashlib
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import sys
13
+ import re
14
+ from dataclasses import dataclass, field, is_dataclass
15
+ from pathlib import Path
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import editdistance
19
+ import torch
20
+ import torch.distributed as dist
21
+ from examples.speech_recognition.new.decoders.decoder_config import (
22
+ DecoderConfig,
23
+ FlashlightDecoderConfig,
24
+ )
25
+ from examples.speech_recognition.new.decoders.decoder import Decoder
26
+ from fairseq import checkpoint_utils, distributed_utils, progress_bar, tasks, utils
27
+ from fairseq.data.data_utils import post_process
28
+ from fairseq.dataclass.configs import (
29
+ CheckpointConfig,
30
+ CommonConfig,
31
+ CommonEvalConfig,
32
+ DatasetConfig,
33
+ DistributedTrainingConfig,
34
+ FairseqDataclass,
35
+ )
36
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
37
+ from fairseq.logging.progress_bar import BaseProgressBar
38
+ from fairseq.models.fairseq_model import FairseqModel
39
+ from omegaconf import OmegaConf
40
+
41
+ import hydra
42
+ from hydra.core.config_store import ConfigStore
43
+
44
+ logging.root.setLevel(logging.INFO)
45
+ logging.basicConfig(level=logging.INFO)
46
+ logger = logging.getLogger(__name__)
47
+
48
+ config_path = Path(__file__).resolve().parent / "conf"
49
+
50
+
51
+ @dataclass
52
+ class DecodingConfig(DecoderConfig, FlashlightDecoderConfig):
53
+ unique_wer_file: bool = field(
54
+ default=False,
55
+ metadata={"help": "If set, use a unique file for storing WER"},
56
+ )
57
+ results_path: Optional[str] = field(
58
+ default=None,
59
+ metadata={
60
+ "help": "If set, write hypothesis and reference sentences into this directory"
61
+ },
62
+ )
63
+
64
+
65
+ @dataclass
66
+ class InferConfig(FairseqDataclass):
67
+ task: Any = None
68
+ decoding: DecodingConfig = DecodingConfig()
69
+ common: CommonConfig = CommonConfig()
70
+ common_eval: CommonEvalConfig = CommonEvalConfig()
71
+ checkpoint: CheckpointConfig = CheckpointConfig()
72
+ distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
73
+ dataset: DatasetConfig = DatasetConfig()
74
+ is_ax: bool = field(
75
+ default=False,
76
+ metadata={
77
+ "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
78
+ },
79
+ )
80
+
81
+
82
+ def reset_logging():
83
+ root = logging.getLogger()
84
+ for handler in root.handlers:
85
+ root.removeHandler(handler)
86
+ root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
87
+ handler = logging.StreamHandler(sys.stdout)
88
+ handler.setFormatter(
89
+ logging.Formatter(
90
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
91
+ datefmt="%Y-%m-%d %H:%M:%S",
92
+ )
93
+ )
94
+ root.addHandler(handler)
95
+
96
+
97
+ class InferenceProcessor:
98
+ cfg: InferConfig
99
+
100
+ def __init__(self, cfg: InferConfig) -> None:
101
+ self.cfg = cfg
102
+ self.task = tasks.setup_task(cfg.task)
103
+
104
+ models, saved_cfg = self.load_model_ensemble()
105
+
106
+ ### LOAD ADAPTER ####
107
+ ckpt_obj = checkpoint_utils.load_checkpoint_to_cpu(self.cfg.common_eval.path)
108
+ if "adapter" in ckpt_obj:
109
+ target_lang = self.cfg.dataset.gen_subset.split(":")[0]
110
+ assert target_lang in ckpt_obj["adapter"]
111
+
112
+ logger.info(f">>> LOADING ADAPTER: {target_lang}")
113
+ ft_obj = ckpt_obj["adapter"][target_lang]
114
+ ft_model = ft_obj["model"]
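+ # The adapter checkpoint may use a different output vocabulary, so rebuild w2v_encoder.proj with the adapter's shape before copying all matching adapter weights into the model.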
115
+ cdevice = models[0].w2v_encoder.proj.weight.device
116
+ cdtype = models[0].w2v_encoder.proj.weight.dtype
117
+ ft_proj_out, ft_proj_in = ft_model["w2v_encoder.proj.weight"].shape
118
+ ft_proj = torch.nn.Linear(ft_proj_in, ft_proj_out, bias=True)
119
+ ft_proj.to(device=cdevice, dtype=cdtype)
120
+ models[0].w2v_encoder.proj = ft_proj
121
+ with torch.no_grad():
122
+ for kk, vv in models[0].named_parameters():
123
+ if kk in ft_model:
124
+ vv.copy_(ft_model[kk])
125
+ self.task.load_state_dict(ft_obj["task_state"])
126
+ # overwrite gen_subset with master config
127
+ self.cfg.dataset.gen_subset = re.sub(r'^[\w-]+:', saved_cfg['task']['multi_corpus_keys'] + ":", self.cfg.dataset.gen_subset)
128
+ self.models = models
129
+ self.saved_cfg = saved_cfg
130
+ self.tgt_dict = self.task.target_dictionary
131
+
132
+ self.task.load_dataset(
133
+ self.cfg.dataset.gen_subset,
134
+ task_cfg=saved_cfg.task,
135
+ )
136
+ self.generator = Decoder(cfg.decoding, self.tgt_dict)
137
+ self.gen_timer = StopwatchMeter()
138
+ self.wps_meter = TimeMeter()
139
+ self.num_sentences = 0
140
+ self.total_errors = 0
141
+ self.total_length = 0
142
+
143
+ self.hypo_words_file = None
144
+ self.hypo_units_file = None
145
+ self.ref_words_file = None
146
+ self.ref_units_file = None
147
+ self.score_file = None
148
+
149
+ self.progress_bar = self.build_progress_bar()
150
+
151
+ def __enter__(self) -> "InferenceProcessor":
152
+ if self.cfg.decoding.results_path is not None:
153
+ self.hypo_words_file = self.get_res_file("hypo.word")
154
+ self.hypo_units_file = self.get_res_file("hypo.units")
155
+ self.ref_words_file = self.get_res_file("ref.word")
156
+ self.ref_units_file = self.get_res_file("ref.units")
157
+ self.score_file = self.get_res_file("asr_score")
158
+ return self
159
+
160
+ def __exit__(self, *exc) -> bool:
161
+ if self.cfg.decoding.results_path is not None:
162
+ self.hypo_words_file.close()
163
+ self.hypo_units_file.close()
164
+ self.ref_words_file.close()
165
+ self.ref_units_file.close()
166
+ self.score_file.close()
167
+ return False
168
+
169
+ def __iter__(self) -> Any:
170
+ for sample in self.progress_bar:
171
+ if not self.cfg.common.cpu:
172
+ sample = utils.move_to_cuda(sample)
173
+
174
+ # Happens on the last batch.
175
+ if "net_input" not in sample:
176
+ continue
177
+ yield sample
178
+
179
+ def log(self, *args, **kwargs):
180
+ self.progress_bar.log(*args, **kwargs)
181
+
182
+ def print(self, *args, **kwargs):
183
+ self.progress_bar.print(*args, **kwargs)
184
+
185
+ def get_res_file(self, fname: str):
186
+ fname = os.path.join(self.cfg.decoding.results_path, fname)
187
+ if self.data_parallel_world_size > 1:
188
+ fname = f"{fname}.{self.data_parallel_rank}"
189
+ return open(fname, "w", buffering=1)
190
+
191
+ def merge_shards(self) -> None:
192
+ """Merges all shard files into shard 0, then removes shard suffix."""
193
+
194
+ shard_id = self.data_parallel_rank
195
+ num_shards = self.data_parallel_world_size
196
+
197
+ if self.data_parallel_world_size > 1:
198
+
199
+ def merge_shards_with_root(fname: str) -> None:
200
+ fname = os.path.join(self.cfg.decoding.results_path, fname)
201
+ logger.info("Merging %s on shard %d", fname, shard_id)
202
+ base_fpath = Path(f"{fname}.0")
203
+ with open(base_fpath, "a") as out_file:
204
+ for s in range(1, num_shards):
205
+ shard_fpath = Path(f"{fname}.{s}")
206
+ with open(shard_fpath, "r") as in_file:
207
+ for line in in_file:
208
+ out_file.write(line)
209
+ shard_fpath.unlink()
210
+ shutil.move(f"{fname}.0", fname)
211
+
212
+ dist.barrier() # ensure all shards finished writing
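+ # Spread the four merge jobs across the first ranks; the modulo keeps the assignment valid when the world size is smaller than four.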
213
+ if shard_id == (0 % num_shards):
214
+ merge_shards_with_root("hypo.word")
215
+ if shard_id == (1 % num_shards):
216
+ merge_shards_with_root("hypo.units")
217
+ if shard_id == (2 % num_shards):
218
+ merge_shards_with_root("ref.word")
219
+ if shard_id == (3 % num_shards):
220
+ merge_shards_with_root("ref.units")
221
+ dist.barrier()
222
+
223
+ def optimize_model(self, model: FairseqModel) -> None:
224
+ model.make_generation_fast_()
225
+ if self.cfg.common.fp16:
226
+ model.half()
227
+ if not self.cfg.common.cpu:
228
+ model.cuda()
229
+
230
+ def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]:
231
+ arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
232
+ models, saved_cfg = checkpoint_utils.load_model_ensemble(
233
+ utils.split_paths(self.cfg.common_eval.path, separator="\\"),
234
+ arg_overrides=arg_overrides,
235
+ task=self.task,
236
+ suffix=self.cfg.checkpoint.checkpoint_suffix,
237
+ strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
238
+ num_shards=self.cfg.checkpoint.checkpoint_shard_count,
239
+ )
240
+ for model in models:
241
+ self.optimize_model(model)
242
+ return models, saved_cfg
243
+
244
+ def get_dataset_itr(self, disable_iterator_cache: bool = False):
245
+ return self.task.get_batch_iterator(
246
+ dataset=self.task.dataset(self.cfg.dataset.gen_subset),
247
+ max_tokens=self.cfg.dataset.max_tokens,
248
+ max_sentences=self.cfg.dataset.batch_size,
249
+ max_positions=(sys.maxsize, sys.maxsize),
250
+ ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
251
+ required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
252
+ seed=self.cfg.common.seed,
253
+ num_shards=self.data_parallel_world_size,
254
+ shard_id=self.data_parallel_rank,
255
+ num_workers=self.cfg.dataset.num_workers,
256
+ data_buffer_size=self.cfg.dataset.data_buffer_size,
257
+ disable_iterator_cache=disable_iterator_cache,
258
+ ).next_epoch_itr(shuffle=False)
259
+
260
+ def build_progress_bar(
261
+ self,
262
+ epoch: Optional[int] = None,
263
+ prefix: Optional[str] = None,
264
+ default_log_format: str = "tqdm",
265
+ ) -> BaseProgressBar:
266
+ return progress_bar.progress_bar(
267
+ iterator=self.get_dataset_itr(),
268
+ log_format=self.cfg.common.log_format,
269
+ log_interval=self.cfg.common.log_interval,
270
+ epoch=epoch,
271
+ prefix=prefix,
272
+ tensorboard_logdir=self.cfg.common.tensorboard_logdir,
273
+ default_log_format=default_log_format,
274
+ )
275
+
276
+ @property
277
+ def data_parallel_world_size(self):
278
+ if self.cfg.distributed_training.distributed_world_size == 1:
279
+ return 1
280
+ return distributed_utils.get_data_parallel_world_size()
281
+
282
+ @property
283
+ def data_parallel_rank(self):
284
+ if self.cfg.distributed_training.distributed_world_size == 1:
285
+ return 0
286
+ return distributed_utils.get_data_parallel_rank()
287
+
288
+ def process_sentence(
289
+ self,
290
+ sample: Dict[str, Any],
291
+ hypo: Dict[str, Any],
292
+ sid: int,
293
+ batch_id: int,
294
+ ) -> Tuple[int, int]:
295
+ speaker = None # Speaker can't be parsed from dataset.
296
+ if "target_label" in sample:
297
+ toks = sample["target_label"]
298
+ else:
299
+ toks = sample["target"]
300
+ toks = toks[batch_id, :]
301
+
302
+ # Processes hypothesis.
303
+ hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu())
304
+ if "words" in hypo:
305
+ hyp_words = " ".join(hypo["words"])
306
+ else:
307
+ hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process)
308
+
309
+ # Processes target.
310
+ target_tokens = utils.strip_pad(toks, self.tgt_dict.pad())
311
+ tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu())
312
+ tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process)
313
+
314
+ if self.cfg.decoding.results_path is not None:
315
+ print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file)
316
+ print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file)
317
+ print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file)
318
+ print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file)
319
+ print(f"{hypo['score'].item()} ({speaker}-{sid})", file=self.score_file)
320
+
321
+ if not self.cfg.common_eval.quiet:
322
+ logger.info(f"HYPO: {hyp_words}")
323
+ logger.info(f"REF: {tgt_words}")
324
+ logger.info("---------------------")
325
+
326
+ hyp_words, tgt_words = hyp_words.split(), tgt_words.split()
327
+
328
+ return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
329
+
330
+ def process_sample(self, sample: Dict[str, Any]) -> None:
331
+ self.gen_timer.start()
332
+ hypos = self.task.inference_step(
333
+ generator=self.generator,
334
+ models=self.models,
335
+ sample=sample,
336
+ )
337
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
338
+ self.gen_timer.stop(num_generated_tokens)
339
+ self.wps_meter.update(num_generated_tokens)
340
+
341
+ for batch_id, sample_id in enumerate(sample["id"].tolist()):
342
+ errs, length = self.process_sentence(
343
+ sample=sample,
344
+ sid=sample_id,
345
+ batch_id=batch_id,
346
+ hypo=hypos[batch_id][0],
347
+ )
348
+ self.total_errors += errs
349
+ self.total_length += length
350
+
351
+ self.log({"wps": round(self.wps_meter.avg)})
352
+ if "nsentences" in sample:
353
+ self.num_sentences += sample["nsentences"]
354
+ else:
355
+ self.num_sentences += sample["id"].numel()
356
+
357
+ def log_generation_time(self) -> None:
358
+ logger.info(
359
+ "Processed %d sentences (%d tokens) in %.1fs (%.2f "
360
+ "sentences per second, %.2f tokens per second)",
361
+ self.num_sentences,
362
+ self.gen_timer.n,
363
+ self.gen_timer.sum,
364
+ self.num_sentences / (self.gen_timer.sum + 1e-6),
365
+ 1.0 / (self.gen_timer.avg + 1e-6),
366
+ )
367
+
368
+
369
+ def parse_wer(wer_file: Path) -> float:
370
+ with open(wer_file, "r") as f:
371
+ return float(f.readline().strip().split(" ")[1])
372
+
373
+
374
+ def get_wer_file(cfg: InferConfig) -> Path:
375
+ """Hashes the decoding parameters to a unique file ID."""
376
+ base_path = "wer"
377
+ if cfg.decoding.results_path is not None:
378
+ base_path = os.path.join(cfg.decoding.results_path, base_path)
379
+
380
+ if cfg.decoding.unique_wer_file:
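+ # Hash the decoding config into a short numeric suffix so runs with different decoding settings write to distinct WER files.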
381
+ yaml_str = OmegaConf.to_yaml(cfg.decoding)
382
+ fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
383
+ return Path(f"{base_path}.{fid % 1000000}")
384
+ else:
385
+ return Path(base_path)
386
+
387
+
388
+ def main(cfg: InferConfig) -> float:
389
+ """Entry point for main processing logic.
390
+
391
+ Args:
392
+ cfg: The inference configuration to use.
393
+
394
+ Returns:
395
+ The word error rate (WER) computed over the evaluation set;
396
+ the value is also written to the WER file on the master rank.
397
+
398
+ """
399
+
400
+ yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg)
401
+
402
+ # Validates the provided configuration.
403
+ if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
404
+ cfg.dataset.max_tokens = 4000000
405
+ if not cfg.common.cpu and not torch.cuda.is_available():
406
+ raise ValueError("CUDA not found; set `cpu=True` to run without CUDA")
407
+
408
+ logger.info(cfg.common_eval.path)
409
+
410
+ with InferenceProcessor(cfg) as processor:
411
+ for sample in processor:
412
+ processor.process_sample(sample)
413
+
414
+ processor.log_generation_time()
415
+
416
+ if cfg.decoding.results_path is not None:
417
+ processor.merge_shards()
418
+
419
+ errs_t, leng_t = processor.total_errors, processor.total_length
420
+
421
+ if cfg.common.cpu:
422
+ logger.warning("Merging WER requires CUDA.")
423
+ elif processor.data_parallel_world_size > 1:
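+ # Sum error and reference-word counts across all ranks so the WER computed below is global rather than per-shard.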
424
+ stats = torch.LongTensor([errs_t, leng_t]).cuda()
425
+ dist.all_reduce(stats, op=dist.ReduceOp.SUM)
426
+ errs_t, leng_t = stats[0].item(), stats[1].item()
427
+
428
+ wer = errs_t * 100.0 / leng_t
429
+
430
+ if distributed_utils.is_master(cfg.distributed_training):
431
+ with open(wer_file, "w") as f:
432
+ f.write(
433
+ (
434
+ f"WER: {wer}\n"
435
+ f"err / num_ref_words = {errs_t} / {leng_t}\n\n"
436
+ f"{yaml_str}"
437
+ )
438
+ )
439
+
440
+ return wer
441
+
442
+
443
+ @hydra.main(config_path=config_path, config_name="infer")
444
+ def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
445
+ container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
446
+ cfg = OmegaConf.create(container)
447
+ OmegaConf.set_struct(cfg, True)
448
+
449
+ if cfg.common.reset_logging:
450
+ reset_logging()
451
+
452
+ utils.import_user_module(cfg.common)
453
+
454
+ # logger.info("Config:\n%s", OmegaConf.to_yaml(cfg))
455
+ wer = float("inf")
456
+
457
+ try:
458
+ if cfg.common.profile:
459
+ with torch.cuda.profiler.profile():
460
+ with torch.autograd.profiler.emit_nvtx():
461
+ distributed_utils.call_main(cfg, main)
462
+ else:
463
+ distributed_utils.call_main(cfg, main)
464
+
465
+ wer = parse_wer(get_wer_file(cfg))
466
+ except BaseException as e: # pylint: disable=broad-except
467
+ if not cfg.common.suppress_crashes:
468
+ raise
469
+ else:
470
+ logger.error("Crashed! %s", str(e))
471
+
472
+ logger.info("Word error rate: %.4f", wer)
473
+ if cfg.is_ax:
474
+ return wer, None
475
+
476
+ return wer
477
+
478
+
479
+ def cli_main() -> None:
480
+ try:
481
+ from hydra._internal.utils import (
482
+ get_args,
483
+ ) # pylint: disable=import-outside-toplevel
484
+
485
+ cfg_name = get_args().config_name or "infer"
486
+ except ImportError:
487
+ logger.warning("Failed to get config name from hydra args")
488
+ cfg_name = "infer"
489
+
490
+ cs = ConfigStore.instance()
491
+ cs.store(name=cfg_name, node=InferConfig)
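+ # Also register each nested dataclass field under its own name so Hydra can resolve overrides for the sub-configs.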
492
+
493
+ for k in InferConfig.__dataclass_fields__:
494
+ if is_dataclass(InferConfig.__dataclass_fields__[k].type):
495
+ v = InferConfig.__dataclass_fields__[k].default
496
+ cs.store(name=k, node=v)
497
+
498
+ hydra_main() # pylint: disable=no-value-for-parameter
499
+
500
+
501
+ if __name__ == "__main__":
502
+ cli_main()
fairseq/examples/speech_recognition/tasks/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ import importlib
2
+ import os
3
+
4
+
5
+ for file in sorted(os.listdir(os.path.dirname(__file__))):
6
+ if file.endswith(".py") and not file.startswith("_"):
7
+ task_name = file[: file.find(".py")]
8
+ importlib.import_module("examples.speech_recognition.tasks." + task_name)
fairseq/examples/speech_recognition/tasks/speech_recognition.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import re
9
+ import sys
10
+
11
+ import torch
12
+ from examples.speech_recognition.data import AsrDataset
13
+ from examples.speech_recognition.data.replabels import replabel_symbol
14
+ from fairseq.data import Dictionary
15
+ from fairseq.tasks import LegacyFairseqTask, register_task
16
+
17
+
18
+ def get_asr_dataset_from_json(data_json_path, tgt_dict):
19
+ """
20
+ Parse data json and create dataset.
21
+ See scripts/asr_prep_json.py, which packs JSON from raw files.
22
+
23
+ Json example:
24
+ {
25
+ "utts": {
26
+ "4771-29403-0025": {
27
+ "input": {
28
+ "length_ms": 170,
29
+ "path": "/tmp/file1.flac"
30
+ },
31
+ "output": {
32
+ "text": "HELLO \n",
33
+ "token": "HE LLO",
34
+ "tokenid": "4815, 861"
35
+ }
36
+ },
37
+ "1564-142299-0096": {
38
+ ...
39
+ }
40
+ }
41
+ """
42
+ if not os.path.isfile(data_json_path):
43
+ raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
44
+ with open(data_json_path, "rb") as f:
45
+ data_samples = json.load(f)["utts"]
46
+ assert len(data_samples) != 0
47
+ sorted_samples = sorted(
48
+ data_samples.items(),
49
+ key=lambda sample: int(sample[1]["input"]["length_ms"]),
50
+ reverse=True,
51
+ )
52
+ aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
53
+ ids = [s[0] for s in sorted_samples]
54
+ speakers = []
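+ # Utterance IDs follow the LibriSpeech "<speaker>-<chapter>-<utterance>" pattern; "<speaker>_<chapter>" is kept as the speaker ID.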
55
+ for s in sorted_samples:
56
+ m = re.search("(.+?)-(.+?)-(.+?)", s[0])
57
+ speakers.append(m.group(1) + "_" + m.group(2))
58
+ frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
59
+ tgt = [
60
+ [int(i) for i in s[1]["output"]["tokenid"].split(", ")]
61
+ for s in sorted_samples
62
+ ]
63
+ # append eos
64
+ tgt = [[*t, tgt_dict.eos()] for t in tgt]
65
+ return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
66
+
67
+
68
+ @register_task("speech_recognition")
69
+ class SpeechRecognitionTask(LegacyFairseqTask):
70
+ """
71
+ Task for training a speech recognition model.
72
+ """
73
+
74
+ @staticmethod
75
+ def add_args(parser):
76
+ """Add task-specific arguments to the parser."""
77
+ parser.add_argument("data", help="path to data directory")
78
+ parser.add_argument(
79
+ "--silence-token", default="\u2581", help="token for silence (used by w2l)"
80
+ )
81
+ parser.add_argument(
82
+ "--max-source-positions",
83
+ default=sys.maxsize,
84
+ type=int,
85
+ metavar="N",
86
+ help="max number of frames in the source sequence",
87
+ )
88
+ parser.add_argument(
89
+ "--max-target-positions",
90
+ default=1024,
91
+ type=int,
92
+ metavar="N",
93
+ help="max number of tokens in the target sequence",
94
+ )
95
+
96
+ def __init__(self, args, tgt_dict):
97
+ super().__init__(args)
98
+ self.tgt_dict = tgt_dict
99
+
100
+ @classmethod
101
+ def setup_task(cls, args, **kwargs):
102
+ """Setup the task (e.g., load dictionaries)."""
103
+ dict_path = os.path.join(args.data, "dict.txt")
104
+ if not os.path.isfile(dict_path):
105
+ raise FileNotFoundError("Dict not found: {}".format(dict_path))
106
+ tgt_dict = Dictionary.load(dict_path)
107
+
108
+ if args.criterion == "ctc_loss":
109
+ tgt_dict.add_symbol("<ctc_blank>")
110
+ elif args.criterion == "asg_loss":
111
+ for i in range(1, args.max_replabel + 1):
112
+ tgt_dict.add_symbol(replabel_symbol(i))
113
+
114
+ print("| dictionary: {} types".format(len(tgt_dict)))
115
+ return cls(args, tgt_dict)
116
+
117
+ def load_dataset(self, split, combine=False, **kwargs):
118
+ """Load a given dataset split.
119
+
120
+ Args:
121
+ split (str): name of the split (e.g., train, valid, test)
122
+ """
123
+ data_json_path = os.path.join(self.args.data, "{}.json".format(split))
124
+ self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
125
+
126
+ def build_generator(self, models, args, **unused):
127
+ w2l_decoder = getattr(args, "w2l_decoder", None)
128
+ if w2l_decoder == "viterbi":
129
+ from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
130
+
131
+ return W2lViterbiDecoder(args, self.target_dictionary)
132
+ elif w2l_decoder == "kenlm":
133
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
134
+
135
+ return W2lKenLMDecoder(args, self.target_dictionary)
136
+ elif w2l_decoder == "fairseqlm":
137
+ from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
138
+
139
+ return W2lFairseqLMDecoder(args, self.target_dictionary)
140
+ else:
141
+ return super().build_generator(models, args)
142
+
143
+ @property
144
+ def target_dictionary(self):
145
+ """Return the :class:`~fairseq.data.Dictionary` for the language
146
+ model."""
147
+ return self.tgt_dict
148
+
149
+ @property
150
+ def source_dictionary(self):
151
+ """Return the source :class:`~fairseq.data.Dictionary` (if applicable
152
+ for this task)."""
153
+ return None
154
+
155
+ def max_positions(self):
156
+ """Return the max speech and sentence length allowed by the task."""
157
+ return (self.args.max_source_positions, self.args.max_target_positions)
fairseq/examples/speech_recognition/utils/wer_utils.py ADDED
@@ -0,0 +1,381 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from __future__ import absolute_import, division, print_function, unicode_literals
9
+
10
+ import re
11
+ from collections import deque
12
+ from enum import Enum
13
+
14
+ import numpy as np
15
+
16
+
17
+ """
18
+ Utility modules for computing Word Error Rate (WER) and alignments,
19
+ as well as more granular metrics such as
20
+ deletions, insertions and substitutions.
21
+ """
22
+
23
+
24
+ class Code(Enum):
25
+ match = 1
26
+ substitution = 2
27
+ insertion = 3
28
+ deletion = 4
29
+
30
+
31
+ class Token(object):
32
+ def __init__(self, lbl="", st=np.nan, en=np.nan):
33
+ if np.isnan(st):
34
+ self.label, self.start, self.end = "", 0.0, 0.0
35
+ else:
36
+ self.label, self.start, self.end = lbl, st, en
37
+
38
+
39
+ class AlignmentResult(object):
40
+ def __init__(self, refs, hyps, codes, score):
41
+ self.refs = refs # std::deque<int>
42
+ self.hyps = hyps # std::deque<int>
43
+ self.codes = codes # std::deque<Code>
44
+ self.score = score # float
45
+
46
+
47
+ def coordinate_to_offset(row, col, ncols):
48
+ return int(row * ncols + col)
49
+
50
+
51
+ def offset_to_row(offset, ncols):
52
+ return int(offset / ncols)
53
+
54
+
55
+ def offset_to_col(offset, ncols):
56
+ return int(offset % ncols)
57
+
58
+
59
+ def trimWhitespace(str):
60
+ return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str)))
61
+
62
+
63
+ def str2toks(str):
64
+ pieces = trimWhitespace(str).split(" ")
65
+ toks = []
66
+ for p in pieces:
67
+ toks.append(Token(p, 0.0, 0.0))
68
+ return toks
69
+
70
+
71
+ class EditDistance(object):
72
+ def __init__(self, time_mediated):
73
+ self.time_mediated_ = time_mediated
74
+ self.scores_ = np.nan # Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>
75
+ self.backtraces_ = (
76
+ np.nan
77
+ ) # Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic> backtraces_;
78
+ self.confusion_pairs_ = {}
79
+
80
+ def cost(self, ref, hyp, code):
81
+ if self.time_mediated_:
82
+ if code == Code.match:
83
+ return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
84
+ elif code == Code.insertion:
85
+ return hyp.end - hyp.start
86
+ elif code == Code.deletion:
87
+ return ref.end - ref.start
88
+ else: # substitution
89
+ return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
90
+ else:
91
+ if code == Code.match:
92
+ return 0
93
+ elif code == Code.insertion or code == Code.deletion:
94
+ return 3
95
+ else: # substitution
96
+ return 4
97
+
98
+ def get_result(self, refs, hyps):
99
+ res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)
100
+
101
+ num_rows, num_cols = self.scores_.shape
102
+ res.score = self.scores_[num_rows - 1, num_cols - 1]
103
+
104
+ curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)
105
+
106
+ while curr_offset != 0:
107
+ curr_row = offset_to_row(curr_offset, num_cols)
108
+ curr_col = offset_to_col(curr_offset, num_cols)
109
+
110
+ prev_offset = self.backtraces_[curr_row, curr_col]
111
+
112
+ prev_row = offset_to_row(prev_offset, num_cols)
113
+ prev_col = offset_to_col(prev_offset, num_cols)
114
+
115
+ res.refs.appendleft(curr_row - 1) # Note: this was .push_front() in C++
116
+ res.hyps.appendleft(curr_col - 1)
117
+ if curr_row - 1 == prev_row and curr_col == prev_col:
118
+ res.codes.appendleft(Code.deletion)
119
+ elif curr_row == prev_row and curr_col - 1 == prev_col:
120
+ res.codes.appendleft(Code.insertion)
121
+ else:
122
+ # assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
123
+ ref_str = refs[res.refs[0]].label
124
+ hyp_str = hyps[res.hyps[0]].label
125
+
126
+ if ref_str == hyp_str:
127
+ res.codes.appendleft(Code.match)
128
+ else:
129
+ res.codes.appendleft(Code.substitution)
130
+
131
+ confusion_pair = "%s -> %s" % (ref_str, hyp_str)
132
+ if confusion_pair not in self.confusion_pairs_:
133
+ self.confusion_pairs_[confusion_pair] = 1
134
+ else:
135
+ self.confusion_pairs_[confusion_pair] += 1
136
+
137
+ curr_offset = prev_offset
138
+
139
+ return res
140
+
141
+ def align(self, refs, hyps):
142
+ if len(refs) == 0 and len(hyps) == 0:
143
+ return np.nan
144
+
145
+ # NOTE: we're not resetting the values in these matrices because every value
146
+ # will be overridden in the loop below. If this assumption doesn't hold,
147
+ # be sure to set all entries in self.scores_ and self.backtraces_ to 0.
148
+ self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
149
+ self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))
150
+
151
+ num_rows, num_cols = self.scores_.shape
152
+
153
+ for i in range(num_rows):
154
+ for j in range(num_cols):
155
+ if i == 0 and j == 0:
156
+ self.scores_[i, j] = 0.0
157
+ self.backtraces_[i, j] = 0
158
+ continue
159
+
160
+ if i == 0:
161
+ self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
162
+ None, hyps[j - 1], Code.insertion
163
+ )
164
+ self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
165
+ continue
166
+
167
+ if j == 0:
168
+ self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
169
+ refs[i - 1], None, Code.deletion
170
+ )
171
+ self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
172
+ continue
173
+
174
+ # Below here both i and j are greater than 0
175
+ ref = refs[i - 1]
176
+ hyp = hyps[j - 1]
177
+ best_score = self.scores_[i - 1, j - 1] + (
178
+ self.cost(ref, hyp, Code.match)
179
+ if (ref.label == hyp.label)
180
+ else self.cost(ref, hyp, Code.substitution)
181
+ )
182
+
183
+ prev_row = i - 1
184
+ prev_col = j - 1
185
+ ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
186
+ if ins < best_score:
187
+ best_score = ins
188
+ prev_row = i
189
+ prev_col = j - 1
190
+
191
+ delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
192
+ if delt < best_score:
193
+ best_score = delt
194
+ prev_row = i - 1
195
+ prev_col = j
196
+
197
+ self.scores_[i, j] = best_score
198
+ self.backtraces_[i, j] = coordinate_to_offset(
199
+ prev_row, prev_col, num_cols
200
+ )
201
+
202
+ return self.get_result(refs, hyps)
203
+
204
+
205
+ class WERTransformer(object):
206
+ def __init__(self, hyp_str, ref_str, verbose=True):
207
+ self.ed_ = EditDistance(False)
208
+ self.id2oracle_errs_ = {}
209
+ self.utts_ = 0
210
+ self.words_ = 0
211
+ self.insertions_ = 0
212
+ self.deletions_ = 0
213
+ self.substitutions_ = 0
214
+
215
+ self.process(["dummy_str", hyp_str, ref_str])
216
+
217
+ if verbose:
218
+ print("'%s' vs '%s'" % (hyp_str, ref_str))
219
+ self.report_result()
220
+
221
+ def process(self, input): # std::vector<std::string>&& input
222
+ if len(input) < 3:
223
+ print(
224
+ "Input must be of the form <id> ... <hypo> <ref> , got ",
225
+ len(input),
226
+ " inputs:",
227
+ )
228
+ return None
229
+
230
+ # Align
231
+ # std::vector<Token> hyps;
232
+ # std::vector<Token> refs;
233
+
234
+ hyps = str2toks(input[-2])
235
+ refs = str2toks(input[-1])
236
+
237
+ alignment = self.ed_.align(refs, hyps)
238
+ if alignment is None:
239
+ print("Alignment is null")
240
+ return np.nan
241
+
242
+ # Tally errors
243
+ ins = 0
244
+ dels = 0
245
+ subs = 0
246
+ for code in alignment.codes:
247
+ if code == Code.substitution:
248
+ subs += 1
249
+ elif code == Code.insertion:
250
+ ins += 1
251
+ elif code == Code.deletion:
252
+ dels += 1
253
+
254
+ # Output
255
+ row = input
256
+ row.append(str(len(refs)))
257
+ row.append(str(ins))
258
+ row.append(str(dels))
259
+ row.append(str(subs))
260
+ # print(row)
261
+
262
+ # Accumulate
263
+ kIdIndex = 0
264
+ kNBestSep = "/"
265
+
266
+ pieces = input[kIdIndex].split(kNBestSep)
267
+
268
+ if len(pieces) == 0:
269
+ print(
270
+ "Error splitting ",
271
+ input[kIdIndex],
272
+ " on '",
273
+ kNBestSep,
274
+ "', got empty list",
275
+ )
276
+ return np.nan
277
+
278
+ id = pieces[0]
279
+ if id not in self.id2oracle_errs_:
280
+ self.utts_ += 1
281
+ self.words_ += len(refs)
282
+ self.insertions_ += ins
283
+ self.deletions_ += dels
284
+ self.substitutions_ += subs
285
+ self.id2oracle_errs_[id] = [ins, dels, subs]
286
+ else:
287
+ curr_err = ins + dels + subs
288
+ prev_err = np.sum(self.id2oracle_errs_[id])
289
+ if curr_err < prev_err:
290
+ self.id2oracle_errs_[id] = [ins, dels, subs]
291
+
292
+ return 0
293
+
294
+ def report_result(self):
295
+ # print("---------- Summary ---------------")
296
+ if self.words_ == 0:
297
+ print("No words counted")
298
+ return
299
+
300
+ # 1-best
301
+ best_wer = (
302
+ 100.0
303
+ * (self.insertions_ + self.deletions_ + self.substitutions_)
304
+ / self.words_
305
+ )
306
+
307
+ print(
308
+ "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
309
+ "%0.2f%% dels, %0.2f%% subs)"
310
+ % (
311
+ best_wer,
312
+ self.utts_,
313
+ self.words_,
314
+ 100.0 * self.insertions_ / self.words_,
315
+ 100.0 * self.deletions_ / self.words_,
316
+ 100.0 * self.substitutions_ / self.words_,
317
+ )
318
+ )
319
+
320
+ def wer(self):
321
+ if self.words_ == 0:
322
+ wer = np.nan
323
+ else:
324
+ wer = (
325
+ 100.0
326
+ * (self.insertions_ + self.deletions_ + self.substitutions_)
327
+ / self.words_
328
+ )
329
+ return wer
330
+
331
+ def stats(self):
332
+ if self.words_ == 0:
333
+ stats = {}
334
+ else:
335
+ wer = (
336
+ 100.0
337
+ * (self.insertions_ + self.deletions_ + self.substitutions_)
338
+ / self.words_
339
+ )
340
+ stats = dict(
341
+ {
342
+ "wer": wer,
343
+ "utts": self.utts_,
344
+ "numwords": self.words_,
345
+ "ins": self.insertions_,
346
+ "dels": self.deletions_,
347
+ "subs": self.substitutions_,
348
+ "confusion_pairs": self.ed_.confusion_pairs_,
349
+ }
350
+ )
351
+ return stats
352
+
353
+
354
+ def calc_wer(hyp_str, ref_str):
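+ # Returns WER in percent, e.g. calc_wer("a b c", "a x c") -> 33.33... (one substitution over three reference words).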
355
+ t = WERTransformer(hyp_str, ref_str, verbose=0)
356
+ return t.wer()
357
+
358
+
359
+ def calc_wer_stats(hyp_str, ref_str):
360
+ t = WERTransformer(hyp_str, ref_str, verbose=0)
361
+ return t.stats()
362
+
363
+
364
+ def get_wer_alignment_codes(hyp_str, ref_str):
365
+ """
366
+ INPUT: hypothesis string, reference string
367
+ OUTPUT: List of alignment codes (intermediate results from WER computation)
368
+ """
369
+ t = WERTransformer(hyp_str, ref_str, verbose=0)
370
+ return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes
371
+
372
+
373
+ def merge_counts(x, y):
374
+ # Merge two hashes which have 'counts' as their values
375
+ # This can be used for example to merge confusion pair counts
376
+ # conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
377
+ for k, v in y.items():
378
+ if k not in x:
379
+ x[k] = 0
380
+ x[k] += v
381
+ return x
fairseq/examples/speech_recognition/w2l_decoder.py ADDED
@@ -0,0 +1,486 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ """
9
+ Flashlight decoders.
10
+ """
11
+
12
+ import gc
13
+ import itertools as it
14
+ import os.path as osp
15
+ from typing import List
16
+ import warnings
17
+ from collections import deque, namedtuple
18
+
19
+ import numpy as np
20
+ import torch
21
+ from examples.speech_recognition.data.replabels import unpack_replabels
22
+ from fairseq import tasks
23
+ from fairseq.utils import apply_to_sample
24
+ from omegaconf import open_dict
25
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
26
+
27
+
28
+ try:
29
+ from flashlight.lib.text.dictionary import create_word_dict, load_words
30
+ from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes
31
+ from flashlight.lib.text.decoder import (
32
+ CriterionType,
33
+ LexiconDecoderOptions,
34
+ KenLM,
35
+ LM,
36
+ LMState,
37
+ SmearingMode,
38
+ Trie,
39
+ LexiconDecoder,
40
+ )
41
+ except Exception:
42
+ warnings.warn(
43
+ "flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
44
+ )
45
+ LM = object
46
+ LMState = object
47
+
48
+
49
+ class W2lDecoder(object):
50
+ def __init__(self, args, tgt_dict):
51
+ self.tgt_dict = tgt_dict
52
+ self.vocab_size = len(tgt_dict)
53
+ self.nbest = args.nbest
54
+
55
+ # criterion-specific init
56
+ self.criterion_type = CriterionType.CTC
57
+ self.blank = (
58
+ tgt_dict.index("<ctc_blank>")
59
+ if "<ctc_blank>" in tgt_dict.indices
60
+ else tgt_dict.bos()
61
+ )
62
+ if "<sep>" in tgt_dict.indices:
63
+ self.silence = tgt_dict.index("<sep>")
64
+ elif "|" in tgt_dict.indices:
65
+ self.silence = tgt_dict.index("|")
66
+ else:
67
+ self.silence = tgt_dict.eos()
68
+ self.asg_transitions = None
69
+
70
+ def generate(self, models, sample, **unused):
71
+ """Generate a batch of inferences."""
72
+ # model.forward normally channels prev_output_tokens into the decoder
73
+ # separately, but SequenceGenerator directly calls model.encoder
74
+ encoder_input = {
75
+ k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
76
+ }
77
+ emissions = self.get_emissions(models, encoder_input)
78
+ return self.decode(emissions)
79
+
80
+ def get_emissions(self, models, encoder_input):
81
+ """Run encoder and normalize emissions"""
82
+ model = models[0]
83
+ encoder_out = model(**encoder_input)
84
+ if hasattr(model, "get_logits"):
85
+ emissions = model.get_logits(encoder_out) # no need to normalize emissions
86
+ else:
87
+ emissions = model.get_normalized_probs(encoder_out, log_probs=True)
88
+ return emissions.transpose(0, 1).float().cpu().contiguous()
89
+
90
+ def get_tokens(self, idxs):
91
+ """Normalize tokens by handling CTC blank, ASG replabels, etc."""
92
+ idxs = (g[0] for g in it.groupby(idxs))
93
+ idxs = filter(lambda x: x != self.blank, idxs)
94
+ return torch.LongTensor(list(idxs))
95
+
96
+
97
+ class W2lViterbiDecoder(W2lDecoder):
98
+ def __init__(self, args, tgt_dict):
99
+ super().__init__(args, tgt_dict)
100
+
101
+ def decode(self, emissions):
102
+ B, T, N = emissions.size()
103
+ hypos = []
104
+ if self.asg_transitions is None:
105
+ transitions = torch.FloatTensor(N, N).zero_()
106
+ else:
107
+ transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
108
+ viterbi_path = torch.IntTensor(B, T)
109
+ workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
110
+ CpuViterbiPath.compute(
111
+ B,
112
+ T,
113
+ N,
114
+ get_data_ptr_as_bytes(emissions),
115
+ get_data_ptr_as_bytes(transitions),
116
+ get_data_ptr_as_bytes(viterbi_path),
117
+ get_data_ptr_as_bytes(workspace),
118
+ )
119
+ return [
120
+ [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
121
+ for b in range(B)
122
+ ]
123
+
124
+
125
+ class W2lKenLMDecoder(W2lDecoder):
126
+ def __init__(self, args, tgt_dict):
127
+ super().__init__(args, tgt_dict)
128
+
129
+ self.unit_lm = getattr(args, "unit_lm", False)
130
+
131
+ if args.lexicon:
132
+ self.lexicon = load_words(args.lexicon)
133
+ self.word_dict = create_word_dict(self.lexicon)
134
+ self.unk_word = self.word_dict.get_index("<unk>")
135
+
136
+ self.lm = KenLM(args.kenlm_model, self.word_dict)
137
+ self.trie = Trie(self.vocab_size, self.silence)
138
+
139
+ start_state = self.lm.start(False)
140
+ for i, (word, spellings) in enumerate(self.lexicon.items()):
141
+ word_idx = self.word_dict.get_index(word)
142
+ _, score = self.lm.score(start_state, word_idx)
143
+ for spelling in spellings:
144
+ spelling_idxs = [tgt_dict.index(token) for token in spelling]
145
+ assert (
146
+ tgt_dict.unk() not in spelling_idxs
147
+ ), f"{spelling} {spelling_idxs}"
148
+ self.trie.insert(spelling_idxs, word_idx, score)
149
+ self.trie.smear(SmearingMode.MAX)
150
+
151
+ self.decoder_opts = LexiconDecoderOptions(
152
+ beam_size=args.beam,
153
+ beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
154
+ beam_threshold=args.beam_threshold,
155
+ lm_weight=args.lm_weight,
156
+ word_score=args.word_score,
157
+ unk_score=args.unk_weight,
158
+ sil_score=args.sil_weight,
159
+ log_add=False,
160
+ criterion_type=self.criterion_type,
161
+ )
162
+
163
+ if self.asg_transitions is None:
164
+ N = 768
165
+ # self.asg_transitions = torch.FloatTensor(N, N).zero_()
166
+ self.asg_transitions = []
167
+
168
+ self.decoder = LexiconDecoder(
169
+ self.decoder_opts,
170
+ self.trie,
171
+ self.lm,
172
+ self.silence,
173
+ self.blank,
174
+ self.unk_word,
175
+ self.asg_transitions,
176
+ self.unit_lm,
177
+ )
178
+ else:
179
+ assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
180
+ from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions
181
+
182
+ d = {w: [[w]] for w in tgt_dict.symbols}
183
+ self.word_dict = create_word_dict(d)
184
+ self.lm = KenLM(args.kenlm_model, self.word_dict)
185
+ self.decoder_opts = LexiconFreeDecoderOptions(
186
+ beam_size=args.beam,
187
+ beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
188
+ beam_threshold=args.beam_threshold,
189
+ lm_weight=args.lm_weight,
190
+ sil_score=args.sil_weight,
191
+ log_add=False,
192
+ criterion_type=self.criterion_type,
193
+ )
194
+ self.decoder = LexiconFreeDecoder(
195
+ self.decoder_opts, self.lm, self.silence, self.blank, []
196
+ )
197
+
198
+ def get_timesteps(self, token_idxs: List[int]) -> List[int]:
199
+ """Returns frame numbers corresponding to every non-blank token.
200
+
201
+ Parameters
202
+ ----------
203
+ token_idxs : List[int]
204
+ IDs of decoded tokens.
205
+
206
+ Returns
207
+ -------
208
+ List[int]
209
+ Frame numbers corresponding to every non-blank token.
210
+ """
211
+ timesteps = []
212
+ for i, token_idx in enumerate(token_idxs):
213
+ if token_idx == self.blank:
214
+ continue
215
+ if i == 0 or token_idx != token_idxs[i-1]:
216
+ timesteps.append(i)
217
+ return timesteps
218
+
219
+ def decode(self, emissions):
220
+ B, T, N = emissions.size()
221
+ hypos = []
222
+ for b in range(B):
223
+ emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
224
+ results = self.decoder.decode(emissions_ptr, T, N)
225
+
226
+ nbest_results = results[: self.nbest]
227
+ hypos.append(
228
+ [
229
+ {
230
+ "tokens": self.get_tokens(result.tokens),
231
+ "score": result.score,
232
+ "timesteps": self.get_timesteps(result.tokens),
233
+ "words": [
234
+ self.word_dict.get_entry(x) for x in result.words if x >= 0
235
+ ],
236
+ }
237
+ for result in nbest_results
238
+ ]
239
+ )
240
+ return hypos
241
+
242
+
243
+ FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])
244
+
245
+
246
+ class FairseqLM(LM):
247
+ def __init__(self, dictionary, model):
248
+ LM.__init__(self)
249
+ self.dictionary = dictionary
250
+ self.model = model
251
+ self.unk = self.dictionary.unk()
252
+
253
+ self.save_incremental = False # this currently does not work properly
254
+ self.max_cache = 20_000
255
+
256
+ model.cuda()
257
+ model.eval()
258
+ model.make_generation_fast_()
259
+
260
+ self.states = {}
261
+ self.stateq = deque()
262
+
263
+ def start(self, start_with_nothing):
264
+ state = LMState()
265
+ prefix = torch.LongTensor([[self.dictionary.eos()]])
266
+ incremental_state = {} if self.save_incremental else None
267
+ with torch.no_grad():
268
+ res = self.model(prefix.cuda(), incremental_state=incremental_state)
269
+ probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
270
+
271
+ if incremental_state is not None:
272
+ incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
273
+ self.states[state] = FairseqLMState(
274
+ prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
275
+ )
276
+ self.stateq.append(state)
277
+
278
+ return state
279
+
280
+ def score(self, state: LMState, token_index: int, no_cache: bool = False):
281
+ """
282
+ Evaluate language model based on the current lm state and new word
283
+ Parameters:
284
+ -----------
285
+ state: current lm state
286
+ token_index: index of the word
287
+ (can be lexicon index then you should store inside LM the
288
+ mapping between indices of lexicon and lm, or lm index of a word)
289
+
290
+ Returns:
291
+ --------
292
+ (LMState, float): pair of (new state, score for the current word)
293
+ """
294
+ curr_state = self.states[state]
295
+
296
+ def trim_cache(targ_size):
297
+ while len(self.stateq) > targ_size:
298
+ rem_k = self.stateq.popleft()
299
+ rem_st = self.states[rem_k]
300
+ rem_st = FairseqLMState(rem_st.prefix, None, None)
301
+ self.states[rem_k] = rem_st
302
+
303
+ if curr_state.probs is None:
304
+ new_incremental_state = (
305
+ curr_state.incremental_state.copy()
306
+ if curr_state.incremental_state is not None
307
+ else None
308
+ )
309
+ with torch.no_grad():
310
+ if new_incremental_state is not None:
311
+ new_incremental_state = apply_to_sample(
312
+ lambda x: x.cuda(), new_incremental_state
313
+ )
314
+ elif self.save_incremental:
315
+ new_incremental_state = {}
316
+
317
+ res = self.model(
318
+ torch.from_numpy(curr_state.prefix).cuda(),
319
+ incremental_state=new_incremental_state,
320
+ )
321
+ probs = self.model.get_normalized_probs(
322
+ res, log_probs=True, sample=None
323
+ )
324
+
325
+ if new_incremental_state is not None:
326
+ new_incremental_state = apply_to_sample(
327
+ lambda x: x.cpu(), new_incremental_state
328
+ )
329
+
330
+ curr_state = FairseqLMState(
331
+ curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
332
+ )
333
+
334
+ if not no_cache:
335
+ self.states[state] = curr_state
336
+ self.stateq.append(state)
337
+
338
+ score = curr_state.probs[token_index].item()
339
+
340
+ trim_cache(self.max_cache)
341
+
342
+ outstate = state.child(token_index)
343
+ if outstate not in self.states and not no_cache:
344
+ prefix = np.concatenate(
345
+ [curr_state.prefix, torch.LongTensor([[token_index]])], -1
346
+ )
347
+ incr_state = curr_state.incremental_state
348
+
349
+ self.states[outstate] = FairseqLMState(prefix, incr_state, None)
350
+
351
+ if token_index == self.unk:
352
+ score = float("-inf")
353
+
354
+ return outstate, score
355
+
356
+ def finish(self, state: LMState):
357
+ """
358
+ Evaluate eos for language model based on the current lm state
359
+
360
+ Returns:
361
+ --------
362
+ (LMState, float): pair of (new state, score for the current word)
363
+ """
364
+ return self.score(state, self.dictionary.eos())
365
+
366
+ def empty_cache(self):
367
+ self.states = {}
368
+ self.stateq = deque()
369
+ gc.collect()
370
+
371
+
372
+ class W2lFairseqLMDecoder(W2lDecoder):
373
+ def __init__(self, args, tgt_dict):
374
+ super().__init__(args, tgt_dict)
375
+
376
+ self.unit_lm = getattr(args, "unit_lm", False)
377
+
378
+ self.lexicon = load_words(args.lexicon) if args.lexicon else None
379
+ self.idx_to_wrd = {}
380
+
381
+ checkpoint = torch.load(args.kenlm_model, map_location="cpu")
382
+
383
+ if "cfg" in checkpoint and checkpoint["cfg"] is not None:
384
+ lm_args = checkpoint["cfg"]
385
+ else:
386
+ lm_args = convert_namespace_to_omegaconf(checkpoint["args"])
387
+
388
+ with open_dict(lm_args.task):
389
+ lm_args.task.data = osp.dirname(args.kenlm_model)
390
+
391
+ task = tasks.setup_task(lm_args.task)
392
+ model = task.build_model(lm_args.model)
393
+ model.load_state_dict(checkpoint["model"], strict=False)
394
+
395
+ self.trie = Trie(self.vocab_size, self.silence)
396
+
397
+ self.word_dict = task.dictionary
398
+ self.unk_word = self.word_dict.unk()
399
+ self.lm = FairseqLM(self.word_dict, model)
400
+
401
+ if self.lexicon:
402
+ start_state = self.lm.start(False)
403
+ for i, (word, spellings) in enumerate(self.lexicon.items()):
404
+ if self.unit_lm:
405
+ word_idx = i
406
+ self.idx_to_wrd[i] = word
407
+ score = 0
408
+ else:
409
+ word_idx = self.word_dict.index(word)
410
+ _, score = self.lm.score(start_state, word_idx, no_cache=True)
411
+
412
+ for spelling in spellings:
413
+ spelling_idxs = [tgt_dict.index(token) for token in spelling]
414
+ assert (
415
+ tgt_dict.unk() not in spelling_idxs
416
+ ), f"{spelling} {spelling_idxs}"
417
+ self.trie.insert(spelling_idxs, word_idx, score)
418
+ self.trie.smear(SmearingMode.MAX)
419
+
420
+ self.decoder_opts = LexiconDecoderOptions(
421
+ beam_size=args.beam,
422
+ beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
423
+ beam_threshold=args.beam_threshold,
424
+ lm_weight=args.lm_weight,
425
+ word_score=args.word_score,
426
+ unk_score=args.unk_weight,
427
+ sil_score=args.sil_weight,
428
+ log_add=False,
429
+ criterion_type=self.criterion_type,
430
+ )
431
+
432
+ self.decoder = LexiconDecoder(
433
+ self.decoder_opts,
434
+ self.trie,
435
+ self.lm,
436
+ self.silence,
437
+ self.blank,
438
+ self.unk_word,
439
+ [],
440
+ self.unit_lm,
441
+ )
442
+ else:
443
+ assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
444
+ from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions
445
+
446
+ d = {w: [[w]] for w in tgt_dict.symbols}
447
+ self.word_dict = create_word_dict(d)
448
+ self.lm = KenLM(args.kenlm_model, self.word_dict)
449
+ self.decoder_opts = LexiconFreeDecoderOptions(
450
+ beam_size=args.beam,
451
+ beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
452
+ beam_threshold=args.beam_threshold,
453
+ lm_weight=args.lm_weight,
454
+ sil_score=args.sil_weight,
455
+ log_add=False,
456
+ criterion_type=self.criterion_type,
457
+ )
458
+ self.decoder = LexiconFreeDecoder(
459
+ self.decoder_opts, self.lm, self.silence, self.blank, []
460
+ )
461
+
462
+ def decode(self, emissions):
463
+ B, T, N = emissions.size()
464
+ hypos = []
465
+
466
+ def idx_to_word(idx):
467
+ if self.unit_lm:
468
+ return self.idx_to_wrd[idx]
469
+ else:
470
+ return self.word_dict[idx]
471
+
472
+ def make_hypo(result):
473
+ hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
474
+ if self.lexicon:
475
+ hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
476
+ return hypo
477
+
478
+ for b in range(B):
479
+ emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
480
+ results = self.decoder.decode(emissions_ptr, T, N)
481
+
482
+ nbest_results = results[: self.nbest]
483
+ hypos.append([make_hypo(result) for result in nbest_results])
484
+ self.lm.empty_cache()
485
+
486
+ return hypos
fairseq/examples/speech_synthesis/README.md ADDED
@@ -0,0 +1,38 @@
1
+ Speech Synthesis (S^2)
2
+ ===
3
+ [https://arxiv.org/abs/2109.06912](https://arxiv.org/abs/2109.06912)
4
+
5
+ Speech synthesis with fairseq.
6
+
7
+ ## Features
8
+
9
+ - Autoregressive and non-autoregressive models
10
+ - Multi-speaker synthesis
11
+ - Audio preprocessing (denoising, VAD, etc.) for less curated data
12
+ - Automatic metrics for model development
13
+ - Similar data configuration as [S2T](../speech_to_text/README.md)
14
+
15
+
16
+ ## Examples
17
+ - [Single-speaker synthesis on LJSpeech](docs/ljspeech_example.md)
18
+ - [Multi-speaker synthesis on VCTK](docs/vctk_example.md)
19
+ - [Multi-speaker synthesis on Common Voice](docs/common_voice_example.md)
20
+
21
+
22
+ ## Citation
23
+ Please cite as:
24
+ ```
25
+ @article{wang2021fairseqs2,
26
+ title={fairseq S\^{}2: A Scalable and Integrable Speech Synthesis Toolkit},
27
+ author={Wang, Changhan and Hsu, Wei-Ning and Adi, Yossi and Polyak, Adam and Lee, Ann and Chen, Peng-Jen and Gu, Jiatao and Pino, Juan},
28
+ journal={arXiv preprint arXiv:2109.06912},
29
+ year={2021}
30
+ }
31
+
32
+ @inproceedings{ott2019fairseq,
33
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
34
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
35
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
36
+ year = {2019},
37
+ }
38
+ ```
fairseq/examples/speech_synthesis/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
fairseq/examples/speech_synthesis/data_utils.py ADDED
@@ -0,0 +1,344 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import io
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Optional, List, Dict
10
+ import zipfile
11
+ import tempfile
12
+ from dataclasses import dataclass
13
+ from itertools import groupby
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+
20
+ from examples.speech_to_text.data_utils import load_tsv_to_dicts
21
+ from fairseq.data.audio.audio_utils import (
22
+ TTSSpectrogram, TTSMelScale, parse_path, read_from_stored_zip, is_npy_data
23
+ )
24
+
25
+
26
+ def trim_or_pad_to_target_length(
27
+ data_1d_or_2d: np.ndarray, target_length: int
28
+ ) -> np.ndarray:
29
+ assert len(data_1d_or_2d.shape) in {1, 2}
30
+ delta = data_1d_or_2d.shape[0] - target_length
31
+ if delta >= 0: # trim if being longer
32
+ data_1d_or_2d = data_1d_or_2d[: target_length]
33
+ else: # pad if being shorter
34
+ if len(data_1d_or_2d.shape) == 1:
35
+ data_1d_or_2d = np.concatenate(
36
+ [data_1d_or_2d, np.zeros(-delta)], axis=0
37
+ )
38
+ else:
39
+ data_1d_or_2d = np.concatenate(
40
+ [data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))],
41
+ axis=0
42
+ )
43
+ return data_1d_or_2d
44
+
45
+
46
+ def extract_logmel_spectrogram(
47
+ waveform: torch.Tensor, sample_rate: int,
48
+ output_path: Optional[Path] = None, win_length: int = 1024,
49
+ hop_length: int = 256, n_fft: int = 1024,
50
+ win_fn: callable = torch.hann_window, n_mels: int = 80,
51
+ f_min: float = 0., f_max: float = 8000, eps: float = 1e-5,
52
+ overwrite: bool = False, target_length: Optional[int] = None
53
+ ):
54
+ if output_path is not None and output_path.is_file() and not overwrite:
55
+ return
56
+
57
+ spectrogram_transform = TTSSpectrogram(
58
+ n_fft=n_fft, win_length=win_length, hop_length=hop_length,
59
+ window_fn=win_fn
60
+ )
61
+ mel_scale_transform = TTSMelScale(
62
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
63
+ n_stft=n_fft // 2 + 1
64
+ )
65
+ spectrogram = spectrogram_transform(waveform)
66
+ mel_spec = mel_scale_transform(spectrogram)
67
+ logmel_spec = torch.clamp(mel_spec, min=eps).log()
68
+ assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
69
+ logmel_spec = logmel_spec.squeeze().t() # D x T -> T x D
70
+ if target_length is not None:
71
+ logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)
72
+
73
+ if output_path is not None:
74
+ np.save(output_path.as_posix(), logmel_spec)
75
+ else:
76
+ return logmel_spec
77
+
78
+
79
+ def extract_pitch(
80
+ waveform: torch.Tensor, sample_rate: int,
81
+ output_path: Optional[Path] = None, hop_length: int = 256,
82
+ log_scale: bool = True, phoneme_durations: Optional[List[int]] = None
83
+ ):
84
+ if output_path is not None and output_path.is_file():
85
+ return
86
+
87
+ try:
88
+ import pyworld
89
+ except ImportError:
90
+ raise ImportError("Please install PyWORLD: pip install pyworld")
91
+
92
+ _waveform = waveform.squeeze(0).double().numpy()
93
+ pitch, t = pyworld.dio(
94
+ _waveform, sample_rate, frame_period=hop_length / sample_rate * 1000
95
+ )
96
+ pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate)
97
+
98
+ if phoneme_durations is not None:
99
+ pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations))
100
+ try:
101
+ from scipy.interpolate import interp1d
102
+ except ImportError:
103
+ raise ImportError("Please install SciPy: pip install scipy")
104
+ nonzero_ids = np.where(pitch != 0)[0]
105
+ if len(nonzero_ids) == 0:
106
+ print(f"{output_path} has all zero values in the pitch contour")
107
+ return
108
+ elif len(nonzero_ids) == 1:
109
+ print(f"{output_path} has only one non-zero value in the pitch contour")
110
+ return
111
+ else:
112
+ interp_fn = interp1d(
113
+ nonzero_ids,
114
+ pitch[nonzero_ids],
115
+ fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
116
+ bounds_error=False,
117
+ )
118
+ pitch = interp_fn(np.arange(0, len(pitch)))
119
+ d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
120
+ pitch = np.array(
121
+ [
122
+ np.mean(pitch[d_cumsum[i-1]: d_cumsum[i]])
123
+ for i in range(1, len(d_cumsum))
124
+ ]
125
+ )
126
+ assert len(pitch) == len(phoneme_durations)
127
+
128
+ if log_scale:
129
+ pitch = np.log(pitch + 1)
130
+
131
+ if output_path is not None:
132
+ np.save(output_path.as_posix(), pitch)
133
+ else:
134
+ return pitch
135
+
136
+
137
+ def extract_energy(
138
+ waveform: torch.Tensor, output_path: Optional[Path] = None,
139
+ hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True,
140
+ phoneme_durations: Optional[List[int]] = None
141
+ ):
142
+ if output_path is not None and output_path.is_file():
143
+ return
144
+
145
+ assert len(waveform.shape) == 2 and waveform.shape[0] == 1
146
+ waveform = waveform.view(1, 1, waveform.shape[1])
147
+ waveform = F.pad(
148
+ waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0],
149
+ mode="reflect"
150
+ )
151
+ waveform = waveform.squeeze(1)
152
+
153
+ fourier_basis = np.fft.fft(np.eye(n_fft))
154
+ cutoff = int((n_fft / 2 + 1))
155
+ fourier_basis = np.vstack(
156
+ [np.real(fourier_basis[:cutoff, :]),
157
+ np.imag(fourier_basis[:cutoff, :])]
158
+ )
159
+
160
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
161
+ forward_transform = F.conv1d(
162
+ waveform, forward_basis, stride=hop_length, padding=0
163
+ )
164
+
165
+ real_part = forward_transform[:, :cutoff, :]
166
+ imag_part = forward_transform[:, cutoff:, :]
167
+ magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
168
+ energy = torch.norm(magnitude, dim=1).squeeze(0).numpy()
169
+
170
+ if phoneme_durations is not None:
171
+ energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations))
172
+ d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
173
+ energy = np.array(
174
+ [
175
+ np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]])
176
+ for i in range(1, len(d_cumsum))
177
+ ]
178
+ )
179
+ assert len(energy) == len(phoneme_durations)
180
+
181
+ if log_scale:
182
+ energy = np.log(energy + 1)
183
+
184
+ if output_path is not None:
185
+ np.save(output_path.as_posix(), energy)
186
+ else:
187
+ return energy
188
+
189
+
190
+ def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None):
191
+ mean_x, mean_x2, n_frames = None, None, 0
192
+ feature_paths = feature_root.glob("*.npy")
193
+ for p in tqdm(feature_paths):
194
+ with open(p, 'rb') as f:
195
+ frames = np.load(f).squeeze()
196
+
197
+ n_frames += frames.shape[0]
198
+
199
+ cur_mean_x = frames.sum(axis=0)
200
+ if mean_x is None:
201
+ mean_x = cur_mean_x
202
+ else:
203
+ mean_x += cur_mean_x
204
+
205
+ cur_mean_x2 = (frames ** 2).sum(axis=0)
206
+ if mean_x2 is None:
207
+ mean_x2 = cur_mean_x2
208
+ else:
209
+ mean_x2 += cur_mean_x2
210
+
211
+ mean_x /= n_frames
212
+ mean_x2 /= n_frames
213
+ var_x = mean_x2 - mean_x ** 2
214
+ std_x = np.sqrt(np.maximum(var_x, 1e-10))
215
+
216
+ if output_path is not None:
217
+ with open(output_path, 'wb') as f:
218
+ np.savez(f, mean=mean_x, std=std_x)
219
+ else:
220
+ return {"mean": mean_x, "std": std_x}
221
+
222
+
223
+ def ipa_phonemize(text, lang="en-us", use_g2p=False):
+     if use_g2p:
+         assert lang == "en-us", "g2pE phonemizer only works for en-us"
+         try:
+             from g2p_en import G2p
+             g2p = G2p()
+             return " ".join("|" if p == " " else p for p in g2p(text))
+         except ImportError:
+             raise ImportError(
+                 "Please install g2p_en: pip install g2p_en"
+             )
+     else:
+         try:
+             from phonemizer import phonemize
+             from phonemizer.separator import Separator
+             return phonemize(
+                 text, backend='espeak', language=lang,
+                 separator=Separator(word="| ", phone=" ")
+             )
+         except ImportError:
+             raise ImportError(
+                 "Please install phonemizer: pip install phonemizer"
+             )
+
+
+ @dataclass
+ class ForceAlignmentInfo(object):
+     tokens: List[str]
+     frame_durations: List[int]
+     start_sec: Optional[float]
+     end_sec: Optional[float]
+
+
+ def get_mfa_alignment_by_sample_id(
+         textgrid_zip_path: str, sample_id: str, sample_rate: int,
+         hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn")
+ ) -> ForceAlignmentInfo:
+     try:
+         import tgt
+     except ImportError:
+         raise ImportError("Please install TextGridTools: pip install tgt")
+
+     filename = f"{sample_id}.TextGrid"
+     out_root = Path(tempfile.gettempdir())
+     tgt_path = out_root / filename
+     with zipfile.ZipFile(textgrid_zip_path) as f_zip:
+         f_zip.extract(filename, path=out_root)
+     textgrid = tgt.io.read_textgrid(tgt_path.as_posix())
+     os.remove(tgt_path)
+
+     phones, frame_durations = [], []
+     start_sec, end_sec, end_idx = 0, 0, 0
+     for t in textgrid.get_tier_by_name("phones")._objects:
+         s, e, p = t.start_time, t.end_time, t.text
+         # Trim leading silences
+         if len(phones) == 0:
+             if p in silence_phones:
+                 continue
+             else:
+                 start_sec = s
+         phones.append(p)
+         if p not in silence_phones:
+             end_sec = e
+             end_idx = len(phones)
+         r = sample_rate / hop_length
+         frame_durations.append(int(np.round(e * r) - np.round(s * r)))
+     # Trim trailing silences
+     phones = phones[:end_idx]
+     frame_durations = frame_durations[:end_idx]
+
+     return ForceAlignmentInfo(
+         tokens=phones, frame_durations=frame_durations, start_sec=start_sec,
+         end_sec=end_sec
+     )
+
+
+ def get_mfa_alignment(
+         textgrid_zip_path: str, sample_ids: List[str], sample_rate: int,
+         hop_length: int
+ ) -> Dict[str, ForceAlignmentInfo]:
+     return {
+         i: get_mfa_alignment_by_sample_id(
+             textgrid_zip_path, i, sample_rate, hop_length
+         ) for i in tqdm(sample_ids)
+     }
+
+
+ def get_unit_alignment(
+         id_to_unit_tsv_path: str, sample_ids: List[str]
+ ) -> Dict[str, ForceAlignmentInfo]:
+     id_to_units = {
+         e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path)
+     }
+     id_to_units = {i: id_to_units[i].split() for i in sample_ids}
+     id_to_units_collapsed = {
+         i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items()
+     }
+     id_to_durations = {
+         i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items()
+     }
+
+     return {
+         i: ForceAlignmentInfo(
+             tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i],
+             start_sec=None, end_sec=None
+         )
+         for i in sample_ids
+     }
+
+
+ def get_feature_value_min_max(feature_paths: List[str]):
+     v_min, v_max = 1e-8, -1e-8
+     for p in tqdm(feature_paths):
+         _path, slice_ptr = parse_path(p)
+         assert len(slice_ptr) == 2
+         byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
+         assert is_npy_data(byte_data)
+         path_or_fp = io.BytesIO(byte_data)
+         features = np.load(path_or_fp).squeeze()
+         v_min = min(v_min, features.min().item())
+         v_max = max(v_max, features.max().item())
+     return v_min, v_max
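Taken together, the helpers above produce the log-Mel, pitch and energy targets used for FastSpeech 2 training. A minimal usage sketch for a single utterance (illustration only; it assumes these helpers live in `examples.speech_synthesis.data_utils` as in upstream fairseq, and that `torchaudio`, `pyworld` and `scipy` are installed):

```python
from pathlib import Path

import torchaudio

# Assumed import path for the helpers shown above.
from examples.speech_synthesis.data_utils import (
    extract_energy, extract_pitch, get_global_cmvn
)

# Load a mono waveform; torchaudio returns a (channels, samples) tensor.
waveform, sample_rate = torchaudio.load("sample.wav")

# With output_path=None the helpers return the contours instead of saving them.
pitch = extract_pitch(waveform, sample_rate, hop_length=256, log_scale=True)
energy = extract_energy(waveform, hop_length=256, n_fft=1024, log_scale=True)

# Global CMVN statistics over a directory of previously saved .npy spectrograms
# (hypothetical path); features are later normalized as (feat - mean) / std.
stats = get_global_cmvn(Path("logmel_features"))
print(pitch.shape, energy.shape, stats["mean"].shape)
```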
fairseq/examples/speech_synthesis/docs/common_voice_example.md ADDED
@@ -0,0 +1,67 @@
+ [[Back]](..)
+
+ # Common Voice
+
+ [Common Voice](https://commonvoice.mozilla.org/en/datasets) is a public domain speech corpus with 11.2K hours of read
+ speech in 76 languages (as of the latest version, v7.0). We provide examples for building
+ [Transformer](https://arxiv.org/abs/1809.08895) models on this dataset.
+
+
+ ## Data preparation
+ [Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path `${DATA_ROOT}/${LANG_ID}`.
+ Create splits and generate audio manifests with
+ ```bash
+ python -m examples.speech_synthesis.preprocessing.get_common_voice_audio_manifest \
+   --data-root ${DATA_ROOT} \
+   --lang ${LANG_ID} \
+   --output-manifest-root ${AUDIO_MANIFEST_ROOT} --convert-to-wav
+ ```
+
+ To denoise audio and trim leading/trailing silence with signal-processing-based VAD, run
+ ```bash
+ for SPLIT in dev test train; do
+   python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \
+     --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
+     --output-dir ${PROCESSED_DATA_ROOT} \
+     --denoise --vad --vad-agg-level 2
+ done
+ ```
+
+ which generates a new audio TSV manifest under `${PROCESSED_DATA_ROOT}` with updated paths to the processed audio and
+ a new column for SNR.
+
+ To filter samples by CER, follow the [Automatic Evaluation](../docs/ljspeech_example.md#automatic-evaluation) section to
+ run an ASR model (add `--eval-target` to `get_eval_manifest` for evaluation on the reference audio; add `--err-unit char`
+ to `eval_asr` to compute CER instead of WER). The example-level CER is saved to
+ `${EVAL_OUTPUT_ROOT}/uer_cer.${SPLIT}.tsv`.
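+
+ A rough sketch of those two additions, applied to the evaluation commands from the LJSpeech example (the remaining
+ arguments are copied from that example; paths and the `--audio-header`/`--text-header` values will likely need
+ adapting when scoring this dataset's reference audio):
+ ```bash
+ python -m examples.speech_synthesis.evaluation.get_eval_manifest \
+   --generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \
+   --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
+   --output-path ${EVAL_OUTPUT_ROOT}/eval.tsv \
+   --vocoder griffin_lim --sample-rate 16000 --audio-format flac \
+   --eval-target
+
+ python -m examples.speech_synthesis.evaluation.eval_asr \
+   --audio-header syn --text-header text --err-unit char --split ${SPLIT} \
+   --w2v-ckpt ${WAV2VEC2_CHECKPOINT_PATH} --w2v-dict-dir ${WAV2VEC2_DICT_DIR} \
+   --raw-manifest ${EVAL_OUTPUT_ROOT}/eval.tsv --asr-dir ${EVAL_OUTPUT_ROOT}/asr
+ ```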
+
+ Then, extract log-Mel spectrograms, generate the feature manifest and create the data configuration YAML with
+ ```bash
+ python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
+   --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
+   --output-root ${FEATURE_MANIFEST_ROOT} \
+   --ipa-vocab --lang ${LANG_ID} \
+   --snr-threshold 15 \
+   --cer-threshold 0.1 --cer-tsv-path ${EVAL_OUTPUT_ROOT}/uer_cer.${SPLIT}.tsv
+ ```
+ where we use phoneme inputs (`--ipa-vocab`) as an example. For sample filtering, we set the SNR and CER thresholds
+ to 15 and 10%, respectively.
+
+
+ ## Training
+ (Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).)
+
+
+ ## Inference
+ (Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).)
+
+ ## Automatic Evaluation
+ (Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).)
+
+ ## Results
+
+ | Language | Speakers | --arch | Params | Test MCD | Model |
+ |---|---|---|---|---|---|
+ | English | 200 | tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/cv4_en200_transformer_phn.tar) |
+
+ [[Back]](..)
fairseq/examples/speech_synthesis/docs/ljspeech_example.md ADDED
@@ -0,0 +1,137 @@
+ [[Back]](..)
+
+ # LJSpeech
+
+ [LJSpeech](https://keithito.com/LJ-Speech-Dataset) is a public domain TTS
+ corpus with around 24 hours of English speech sampled at 22.05 kHz. We provide examples for building
+ [Transformer](https://arxiv.org/abs/1809.08895) and [FastSpeech 2](https://arxiv.org/abs/2006.04558)
+ models on this dataset.
+
+
+ ## Data preparation
+
+ Download the data, create splits and generate audio manifests with
+ ```bash
+ python -m examples.speech_synthesis.preprocessing.get_ljspeech_audio_manifest \
+   --output-data-root ${AUDIO_DATA_ROOT} \
+   --output-manifest-root ${AUDIO_MANIFEST_ROOT}
+ ```
+
+ Then, extract log-Mel spectrograms, generate the feature manifest and create the data configuration YAML with
+ ```bash
+ python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
+   --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
+   --output-root ${FEATURE_MANIFEST_ROOT} \
+   --ipa-vocab --use-g2p
+ ```
+ where we use phoneme inputs (`--ipa-vocab --use-g2p`) as an example.
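+
+ As a quick illustration of what these phoneme inputs look like (a sketch, assuming `g2p_en` is installed and that
+ the preprocessing relies on the `ipa_phonemize` helper from `examples/speech_synthesis/data_utils.py` shown earlier
+ in this diff; the exact output can vary by version):
+ ```python
+ from examples.speech_synthesis.data_utils import ipa_phonemize
+
+ print(ipa_phonemize("Hello world", lang="en-us", use_g2p=True))
+ # something like: HH AH0 L OW1 | W ER1 L D
+ ```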
+
+ FastSpeech 2 additionally requires frame durations, pitch and energy as auxiliary training targets.
+ Add `--add-fastspeech-targets` to include these fields in the feature manifests. We get frame durations either from
+ phoneme-level forced alignment or from a frame-level pseudo-text unit sequence. They should be pre-computed and specified via:
+ - `--textgrid-zip ${TEXT_GRID_ZIP_PATH}` for a ZIP file, inside which there is one
+   [TextGrid](https://www.fon.hum.uva.nl/praat/manual/TextGrid.html) file per sample to provide forced-alignment info.
+ - `--id-to-units-tsv ${ID_TO_UNIT_TSV}` for a TSV file with two columns, for the sample ID and the
+   space-delimited pseudo-text unit sequence, respectively.
+
+ For your convenience, we provide pre-computed
+ [forced alignments](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_mfa.zip) from the
+ [Montreal Forced Aligner](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) and
+ [pseudo-text units](s3://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_hubert.tsv) from
+ [HuBERT](https://github.com/pytorch/fairseq/tree/main/examples/hubert). You can also generate them yourself using
+ different software or models.
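+
+ For instance, with the pre-computed MFA alignment downloaded to `${TEXT_GRID_ZIP_PATH}`, the feature-manifest step
+ above might be rerun as follows (a sketch; only `--add-fastspeech-targets` and `--textgrid-zip` are new relative to
+ the earlier command):
+ ```bash
+ python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
+   --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
+   --output-root ${FEATURE_MANIFEST_ROOT} \
+   --ipa-vocab --use-g2p \
+   --add-fastspeech-targets --textgrid-zip ${TEXT_GRID_ZIP_PATH}
+ ```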
+
+
+ ## Training
+ #### Transformer
+ ```bash
+ fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
+   --config-yaml config.yaml --train-subset train --valid-subset dev \
+   --num-workers 4 --max-tokens 30000 --max-update 200000 \
+   --task text_to_speech --criterion tacotron2 --arch tts_transformer \
+   --clip-norm 5.0 --n-frames-per-step 4 --bce-pos-weight 5.0 \
+   --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \
+   --encoder-normalize-before --decoder-normalize-before \
+   --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
+   --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
+ ```
+ where `SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to
+ adjust it when using more than 1 GPU, e.g. `--update-freq 2` on 4 GPUs to keep the effective batch size unchanged.
+
+ #### FastSpeech2
+ ```bash
+ fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
+   --config-yaml config.yaml --train-subset train --valid-subset dev \
+   --num-workers 4 --max-sentences 6 --max-update 200000 \
+   --task text_to_speech --criterion fastspeech2 --arch fastspeech2 \
+   --clip-norm 5.0 --n-frames-per-step 1 \
+   --dropout 0.1 --attention-dropout 0.1 \
+   --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
+   --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
+ ```
+
+
+ ## Inference
+ Average the last 5 checkpoints, then generate spectrograms for the test split and synthesize waveforms with the default Griffin-Lim vocoder:
+ ```bash
+ SPLIT=test
+ CHECKPOINT_NAME=avg_last_5
+ CHECKPOINT_PATH=${SAVE_DIR}/checkpoint_${CHECKPOINT_NAME}.pt
+ python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \
+   --num-epoch-checkpoints 5 \
+   --output ${CHECKPOINT_PATH}
+
+ python -m examples.speech_synthesis.generate_waveform ${FEATURE_MANIFEST_ROOT} \
+   --config-yaml config.yaml --gen-subset ${SPLIT} --task text_to_speech \
+   --path ${CHECKPOINT_PATH} --max-tokens 50000 --spec-bwd-max-iter 32 \
+   --dump-waveforms
+ ```
+ which dumps files (waveform, feature, attention plot, etc.) to `${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT}`. To
+ re-synthesize target waveforms for automatic evaluation, add `--dump-target`.
+
+ ## Automatic Evaluation
+ To start with, generate the manifest for the synthetic speech, which will be taken as input by the evaluation scripts.
+ ```bash
+ python -m examples.speech_synthesis.evaluation.get_eval_manifest \
+   --generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \
+   --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
+   --output-path ${EVAL_OUTPUT_ROOT}/eval.tsv \
+   --vocoder griffin_lim --sample-rate 22050 --audio-format flac \
+   --use-resynthesized-target
+ ```
+ Speech recognition (ASR) models usually operate at lower sample rates (e.g. 16 kHz). For the WER/CER metric,
+ you may need to resample the audio accordingly: add `--output-sample-rate 16000` to `generate_waveform.py` and
+ use `--sample-rate 16000` for `get_eval_manifest.py`.
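+
+ Putting this together, a sketch of a 16 kHz evaluation setup (arguments otherwise copied from the commands above;
+ adjust as needed):
+ ```bash
+ python -m examples.speech_synthesis.generate_waveform ${FEATURE_MANIFEST_ROOT} \
+   --config-yaml config.yaml --gen-subset ${SPLIT} --task text_to_speech \
+   --path ${CHECKPOINT_PATH} --max-tokens 50000 --spec-bwd-max-iter 32 \
+   --dump-waveforms --output-sample-rate 16000
+
+ python -m examples.speech_synthesis.evaluation.get_eval_manifest \
+   --generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \
+   --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
+   --output-path ${EVAL_OUTPUT_ROOT}/eval_16khz.tsv \
+   --vocoder griffin_lim --sample-rate 16000 --audio-format flac \
+   --use-resynthesized-target
+ ```
+ The resulting `eval_16khz.tsv` is what the WER/CER command below consumes via `--raw-manifest`.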
+
+
+ #### WER/CER metric
+ We use a wav2vec 2.0 ASR model as an example. [Download](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec)
+ the model checkpoint and dictionary, then compute WER/CER with
+ ```bash
+ python -m examples.speech_synthesis.evaluation.eval_asr \
+   --audio-header syn --text-header text --err-unit char --split ${SPLIT} \
+   --w2v-ckpt ${WAV2VEC2_CHECKPOINT_PATH} --w2v-dict-dir ${WAV2VEC2_DICT_DIR} \
+   --raw-manifest ${EVAL_OUTPUT_ROOT}/eval_16khz.tsv --asr-dir ${EVAL_OUTPUT_ROOT}/asr
+ ```
+
+ #### MCD/MSD metric
+ ```bash
+ python -m examples.speech_synthesis.evaluation.eval_sp \
+   ${EVAL_OUTPUT_ROOT}/eval.tsv --mcd --msd
+ ```
+
+ #### F0 metrics
+ ```bash
+ python -m examples.speech_synthesis.evaluation.eval_f0 \
+   ${EVAL_OUTPUT_ROOT}/eval.tsv --gpe --vde --ffe
+ ```
+
+
+ ## Results
+
+ | --arch | Params | Test MCD | Model |
+ |---|---|---|---|
+ | tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_transformer_phn.tar) |
+ | fastspeech2 | 41M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_fastspeech2_phn.tar) |
+
+ [[Back]](..)