Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- fairseq/examples/hubert/tests/sample.base.L9.npy +3 -0
- fairseq/examples/hubert/tests/sample.large.L20.npy +3 -0
- fairseq/examples/simultaneous_translation/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/examples/simultaneous_translation/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/examples/simultaneous_translation/utils/__pycache__/functions.cpython-310.pyc +0 -0
- fairseq/examples/simultaneous_translation/utils/__pycache__/p_choose_strategy.cpython-310.pyc +0 -0
- fairseq/examples/speech_recognition/README.md +87 -0
- fairseq/examples/speech_recognition/__init__.py +1 -0
- fairseq/examples/speech_recognition/criterions/ASG_loss.py +170 -0
- fairseq/examples/speech_recognition/criterions/__init__.py +17 -0
- fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py +130 -0
- fairseq/examples/speech_recognition/data/__init__.py +11 -0
- fairseq/examples/speech_recognition/data/asr_dataset.py +122 -0
- fairseq/examples/speech_recognition/data/collaters.py +131 -0
- fairseq/examples/speech_recognition/data/data_utils.py +100 -0
- fairseq/examples/speech_recognition/data/replabels.py +70 -0
- fairseq/examples/speech_recognition/datasets/asr_prep_json.py +125 -0
- fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh +88 -0
- fairseq/examples/speech_recognition/infer.py +436 -0
- fairseq/examples/speech_recognition/kaldi/__init__.py +0 -0
- fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc +94 -0
- fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml +8 -0
- fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py +244 -0
- fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py +698 -0
- fairseq/examples/speech_recognition/models/__init__.py +8 -0
- fairseq/examples/speech_recognition/models/vggtransformer.py +1020 -0
- fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py +177 -0
- fairseq/examples/speech_recognition/new/README.md +43 -0
- fairseq/examples/speech_recognition/new/__init__.py +0 -0
- fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml +29 -0
- fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml +29 -0
- fairseq/examples/speech_recognition/new/conf/infer.yaml +27 -0
- fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml +28 -0
- fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml +27 -0
- fairseq/examples/speech_recognition/new/decoders/__init__.py +0 -0
- fairseq/examples/speech_recognition/new/decoders/base_decoder.py +62 -0
- fairseq/examples/speech_recognition/new/decoders/decoder.py +32 -0
- fairseq/examples/speech_recognition/new/decoders/decoder_config.py +70 -0
- fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py +433 -0
- fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py +24 -0
- fairseq/examples/speech_recognition/new/infer.py +502 -0
- fairseq/examples/speech_recognition/tasks/__init__.py +8 -0
- fairseq/examples/speech_recognition/tasks/speech_recognition.py +157 -0
- fairseq/examples/speech_recognition/utils/wer_utils.py +381 -0
- fairseq/examples/speech_recognition/w2l_decoder.py +486 -0
- fairseq/examples/speech_synthesis/README.md +38 -0
- fairseq/examples/speech_synthesis/__init__.py +4 -0
- fairseq/examples/speech_synthesis/data_utils.py +344 -0
- fairseq/examples/speech_synthesis/docs/common_voice_example.md +67 -0
- fairseq/examples/speech_synthesis/docs/ljspeech_example.md +137 -0
fairseq/examples/hubert/tests/sample.base.L9.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b44dc3a0519f6adb670a5da7410c1405f4fdb0e4866e5fc4983b136480212f34
+size 1831104
fairseq/examples/hubert/tests/sample.large.L20.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c82839134cc2340355eb49b41e7e3517e8e9dfdfaa6fc28f0464cc8ae9569ee
+size 2441408
fairseq/examples/simultaneous_translation/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (258 Bytes).
fairseq/examples/simultaneous_translation/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (509 Bytes).
fairseq/examples/simultaneous_translation/utils/__pycache__/functions.cpython-310.pyc
ADDED
Binary file (3.29 kB).
fairseq/examples/simultaneous_translation/utils/__pycache__/p_choose_strategy.cpython-310.pyc
ADDED
Binary file (1.79 kB).
fairseq/examples/speech_recognition/README.md
ADDED
@@ -0,0 +1,87 @@
+### 2021 Update: We are merging this example into the [S2T framework](../speech_to_text), which supports more generic speech-to-text tasks (e.g. speech translation) and more flexible data processing pipelines. Please stay tuned.
+
+# Speech Recognition
+`examples/speech_recognition` implements an ASR task in fairseq, along with the features, datasets, models and loss functions needed to train and run inference with the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
+
+
+## Additional dependencies
+On top of the main fairseq dependencies, there are a couple of additional requirements.
+
+1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
+2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from the sctk package [here](http://www.openslr.org/4/). Training and inference do not require the Sclite dependency.
+3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create a dataset with word-piece targets.
+
+## Preparing librispeech data
+```
+./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
+```
+
+## Training librispeech data
+```
+python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
+```
+
+## Inference for librispeech
+`$SET` can be `test_clean` or `test_other`.
+Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`.
+```
+python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
+```
+
+## Scoring with Sclite
+```
+sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
+```
+The `Sum/Avg` row of the first table in the report contains the WER.
+
+## Using flashlight (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) components
+[flashlight](https://github.com/facebookresearch/flashlight) now has integration with fairseq. Currently this includes:
+
+* AutoSegmentationCriterion (ASG)
+* flashlight-style Conv/GLU model
+* flashlight's beam search decoder
+
+To use these, follow the instructions on [this page](https://github.com/flashlight/flashlight/tree/e16682fa32df30cbf675c8fe010f929c61e3b833/bindings/python) to install the python bindings. **Flashlight v0.3.2** must be used to install the bindings. Running:
+```
+git clone --branch v0.3.2 https://github.com/flashlight/flashlight
+```
+will properly clone and check out this version.
+
+## Training librispeech data (flashlight style, Conv/GLU + ASG loss)
+Training command:
+```
+python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
+```
+
+Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
+
+## Inference for librispeech (flashlight decoder, n-gram LM)
+Inference command:
+```
+python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
+```
+
+`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a flashlight-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
+```
+doorbell  D O 1 R B E L 1 ▁
+```
+For CTC inference with word-pieces, repetition labels are not used and the lexicon should have the most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this):
+```
+doorbell  ▁DOOR BE LL
+doorbell  ▁DOOR B E LL
+doorbell  ▁DO OR BE LL
+doorbell  ▁DOOR B EL L
+doorbell  ▁DOOR BE L L
+doorbell  ▁DO OR B E LL
+doorbell  ▁DOOR B E L L
+doorbell  ▁DO OR B EL L
+doorbell  ▁DO O R BE LL
+doorbell  ▁DO OR BE L L
+```
+Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
+
+## Inference for librispeech (flashlight decoder, viterbi only)
+Inference command:
+```
+python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
+```
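The README above suggests sentencepiece's `NBestEncodeAsPieces` for building the word-piece lexicon. A minimal sketch of that step, assuming a trained `spm.model` and a plain word list; the file names and the n-best size of 10 are illustrative, not something this commit ships:
```
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("spm.model")  # e.g. the model copied out by prepare-librispeech.sh

with open("words.txt") as f, open("lexicon.txt", "w") as out:
    for word in (line.strip() for line in f):
        # up to 10 alternative segmentations per word, one lexicon line each
        for pieces in sp.NBestEncodeAsPieces(word, 10):
            print(word, " ".join(pieces), file=out)
```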
fairseq/examples/speech_recognition/__init__.py
ADDED
@@ -0,0 +1 @@
+from . import criterions, models, tasks  # noqa
fairseq/examples/speech_recognition/criterions/ASG_loss.py
ADDED
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from examples.speech_recognition.data.replabels import pack_replabels
+from fairseq import utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+
+
+@register_criterion("asg_loss")
+class ASGCriterion(FairseqCriterion):
+    @staticmethod
+    def add_args(parser):
+        group = parser.add_argument_group("ASG Loss")
+        group.add_argument(
+            "--asg-transitions-init",
+            help="initial diagonal value of transition matrix",
+            type=float,
+            default=0.0,
+        )
+        group.add_argument(
+            "--max-replabel", help="maximum # of replabels", type=int, default=2
+        )
+        group.add_argument(
+            "--linseg-updates",
+            help="# of training updates to use LinSeg initialization",
+            type=int,
+            default=0,
+        )
+        group.add_argument(
+            "--hide-linseg-messages",
+            help="hide messages about LinSeg initialization",
+            action="store_true",
+        )
+
+    def __init__(
+        self,
+        task,
+        silence_token,
+        asg_transitions_init,
+        max_replabel,
+        linseg_updates,
+        hide_linseg_messages,
+    ):
+        from flashlight.lib.sequence.criterion import ASGLoss, CriterionScaleMode
+
+        super().__init__(task)
+        self.tgt_dict = task.target_dictionary
+        self.eos = self.tgt_dict.eos()
+        self.silence = (
+            self.tgt_dict.index(silence_token)
+            if silence_token in self.tgt_dict
+            else None
+        )
+        self.max_replabel = max_replabel
+
+        num_labels = len(self.tgt_dict)
+        self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
+        self.asg.trans = torch.nn.Parameter(
+            asg_transitions_init * torch.eye(num_labels), requires_grad=True
+        )
+
+        self.linseg_progress = torch.nn.Parameter(
+            torch.tensor([0], dtype=torch.int), requires_grad=False
+        )
+        self.linseg_maximum = linseg_updates
+        self.linseg_message_state = "none" if hide_linseg_messages else "start"
+
+    @classmethod
+    def build_criterion(cls, args, task):
+        return cls(
+            task,
+            args.silence_token,
+            args.asg_transitions_init,
+            args.max_replabel,
+            args.linseg_updates,
+            args.hide_linseg_messages,
+        )
+
+    def linseg_step(self):
+        if not self.training:
+            return False
+        if self.linseg_progress.item() < self.linseg_maximum:
+            if self.linseg_message_state == "start":
+                print("| using LinSeg to initialize ASG")
+                self.linseg_message_state = "finish"
+            self.linseg_progress.add_(1)
+            return True
+        elif self.linseg_message_state == "finish":
+            print("| finished LinSeg initialization")
+            self.linseg_message_state = "none"
+        return False
+
+    def replace_eos_with_silence(self, tgt):
+        if tgt[-1] != self.eos:
+            return tgt
+        elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
+            return tgt[:-1]
+        else:
+            return tgt[:-1] + [self.silence]
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+
+        net_output = model(**sample["net_input"])
+        emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
+        B = emissions.size(0)
+        T = emissions.size(1)
+        device = emissions.device
+
+        target = torch.IntTensor(B, T)
+        target_size = torch.IntTensor(B)
+        using_linseg = self.linseg_step()
+
+        for b in range(B):
+            initial_target_size = sample["target_lengths"][b].item()
+            if initial_target_size == 0:
+                raise ValueError("target size cannot be zero")
+
+            tgt = sample["target"][b, :initial_target_size].tolist()
+            tgt = self.replace_eos_with_silence(tgt)
+            tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
+            tgt = tgt[:T]
+
+            if using_linseg:
+                tgt = [tgt[t * len(tgt) // T] for t in range(T)]
+
+            target[b][: len(tgt)] = torch.IntTensor(tgt)
+            target_size[b] = len(tgt)
+
+        loss = self.asg.forward(emissions, target.to(device), target_size.to(device))
+
+        if reduce:
+            loss = torch.sum(loss)
+
+        sample_size = (
+            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
+        )
+        logging_output = {
+            "loss": utils.item(loss.data) if reduce else loss.data,
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["target"].size(0),
+            "sample_size": sample_size,
+        }
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        agg_output = {
+            "loss": loss_sum / nsentences,
+            "ntokens": ntokens,
+            "nsentences": nsentences,
+            "sample_size": sample_size,
+        }
+        return agg_output
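For intuition, the LinSeg branch in `forward` stretches the packed target across all `T` emission frames with `[tgt[t * len(tgt) // T] for t in range(T)]`, so each label covers a roughly equal share of the frames. A tiny standalone check (toy ids only):
```
tgt = [10, 11, 12]   # three packed target ids
T = 6                # number of emission frames
linseg_tgt = [tgt[t * len(tgt) // T] for t in range(T)]
print(linseg_tgt)    # [10, 10, 11, 11, 12, 12]
```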
fairseq/examples/speech_recognition/criterions/__init__.py
ADDED
@@ -0,0 +1,17 @@
+import importlib
+import os
+
+
+# ASG loss requires flashlight bindings
+files_to_skip = set()
+try:
+    import flashlight.lib.sequence.criterion
+except ImportError:
+    files_to_skip.add("ASG_loss.py")
+
+for file in sorted(os.listdir(os.path.dirname(__file__))):
+    if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip:
+        criterion_name = file[: file.find(".py")]
+        importlib.import_module(
+            "examples.speech_recognition.criterions." + criterion_name
+        )
fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py
ADDED
@@ -0,0 +1,130 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import logging
+import math
+
+import torch
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+
+
+@register_criterion("cross_entropy_acc")
+class CrossEntropyWithAccCriterion(FairseqCriterion):
+    def __init__(self, task, sentence_avg):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+
+    def compute_loss(self, model, net_output, target, reduction, log_probs):
+        # N, T -> N * T
+        target = target.view(-1)
+        lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
+        if not hasattr(lprobs, "batch_first"):
+            logging.warning(
+                "ERROR: we need to know whether "
+                "batch first for the net output; "
+                "you need to set batch_first attribute for the return value of "
+                "model.get_normalized_probs. Now, we assume this is true, but "
+                "in the future, we will raise exception instead. "
+            )
+        batch_first = getattr(lprobs, "batch_first", True)
+        if not batch_first:
+            lprobs = lprobs.transpose(0, 1)
+
+        # N, T, D -> N * T, D
+        lprobs = lprobs.view(-1, lprobs.size(-1))
+        loss = F.nll_loss(
+            lprobs, target, ignore_index=self.padding_idx, reduction=reduction
+        )
+        return lprobs, loss
+
+    def get_logging_output(self, sample, target, lprobs, loss):
+        target = target.view(-1)
+        mask = target != self.padding_idx
+        correct = torch.sum(
+            lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
+        )
+        total = torch.sum(mask)
+        sample_size = (
+            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+        )
+
+        logging_output = {
+            "loss": utils.item(loss.data),  # * sample['ntokens'],
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["target"].size(0),
+            "sample_size": sample_size,
+            "correct": utils.item(correct.data),
+            "total": utils.item(total.data),
+            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
+        }
+
+        return sample_size, logging_output
+
+    def forward(self, model, sample, reduction="sum", log_probs=True):
+        """Computes the cross entropy with accuracy metric for the given sample.
+
+        This is similar to CrossEntropyCriterion in fairseq, but also
+        computes accuracy metrics as part of logging
+
+        Args:
+            logprobs (Torch.tensor) of shape N, T, D i.e.
+                batchsize, timesteps, dimensions
+            targets (Torch.tensor) of shape N, T i.e batchsize, timesteps
+
+        Returns:
+            tuple: With three elements:
+                1) the loss
+                2) the sample size, which is used as the denominator for the gradient
+                3) logging outputs to display while training
+
+        TODO:
+            * Currently this Criterion will only work with LSTMEncoderModels or
+              FairseqModels which have decoder, or Models which return TorchTensor
+              as net_output.
+              We need to make a change to support all FairseqEncoder models.
+        """
+        net_output = model(**sample["net_input"])
+        target = model.get_targets(sample, net_output)
+        lprobs, loss = self.compute_loss(
+            model, net_output, target, reduction, log_probs
+        )
+        sample_size, logging_output = self.get_logging_output(
+            sample, target, lprobs, loss
+        )
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
+        total_sum = sum(log.get("total", 0) for log in logging_outputs)
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        nframes = sum(log.get("nframes", 0) for log in logging_outputs)
+        agg_output = {
+            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
+            # if args.sentence_avg, then sample_size is nsentences, then loss
+            # is per-sentence loss; else sample_size is ntokens, the loss
+            # becomes per-output token loss
+            "ntokens": ntokens,
+            "nsentences": nsentences,
+            "nframes": nframes,
+            "sample_size": sample_size,
+            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
+            "correct": correct_sum,
+            "total": total_sum,
+            # total is the number of validate tokens
+        }
+        if sample_size != ntokens:
+            agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
+            # loss: per output token loss
+            # nll_loss: per sentence loss
+        return agg_output
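The accuracy reported by `get_logging_output` is simply an argmax match count over non-padding positions. A self-contained sketch of that computation (random log-probs and a hypothetical padding index of 1, purely for illustration):
```
import torch

padding_idx = 1
lprobs = torch.log_softmax(torch.randn(6, 5), dim=-1)   # flattened (N*T, D) log-probs
target = torch.tensor([3, 1, 4, 2, 1, 0])               # flattened (N*T,) targets

mask = target != padding_idx
correct = torch.sum(lprobs.argmax(1).masked_select(mask) == target.masked_select(mask))
total = torch.sum(mask)
print(100.0 * correct.item() / total.item())             # accuracy in percent, padding ignored
```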
fairseq/examples/speech_recognition/data/__init__.py
ADDED
@@ -0,0 +1,11 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .asr_dataset import AsrDataset
+
+
+__all__ = [
+    "AsrDataset",
+]
fairseq/examples/speech_recognition/data/asr_dataset.py
ADDED
@@ -0,0 +1,122 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+import numpy as np
+from fairseq.data import FairseqDataset
+
+from . import data_utils
+from .collaters import Seq2SeqCollater
+
+
+class AsrDataset(FairseqDataset):
+    """
+    A dataset representing speech and corresponding transcription.
+
+    Args:
+        aud_paths: (List[str]): A list of str with paths to audio files.
+        aud_durations_ms (List[int]): A list of int containing the durations of
+            audio files.
+        tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
+            of target transcriptions.
+        tgt_dict (~fairseq.data.Dictionary): target vocabulary.
+        ids (List[str]): A list of utterance IDs.
+        speakers (List[str]): A list of speakers corresponding to utterances.
+        num_mel_bins (int): Number of triangular mel-frequency bins (default: 80)
+        frame_length (float): Frame length in milliseconds (default: 25.0)
+        frame_shift (float): Frame shift in milliseconds (default: 10.0)
+    """
+
+    def __init__(
+        self,
+        aud_paths,
+        aud_durations_ms,
+        tgt,
+        tgt_dict,
+        ids,
+        speakers,
+        num_mel_bins=80,
+        frame_length=25.0,
+        frame_shift=10.0,
+    ):
+        assert frame_length > 0
+        assert frame_shift > 0
+        assert all(x > frame_length for x in aud_durations_ms)
+        self.frame_sizes = [
+            int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
+        ]
+
+        assert len(aud_paths) > 0
+        assert len(aud_paths) == len(aud_durations_ms)
+        assert len(aud_paths) == len(tgt)
+        assert len(aud_paths) == len(ids)
+        assert len(aud_paths) == len(speakers)
+        self.aud_paths = aud_paths
+        self.tgt_dict = tgt_dict
+        self.tgt = tgt
+        self.ids = ids
+        self.speakers = speakers
+        self.num_mel_bins = num_mel_bins
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+
+        self.s2s_collater = Seq2SeqCollater(
+            0,
+            1,
+            pad_index=self.tgt_dict.pad(),
+            eos_index=self.tgt_dict.eos(),
+            move_eos_to_beginning=True,
+        )
+
+    def __getitem__(self, index):
+        import torchaudio
+        import torchaudio.compliance.kaldi as kaldi
+
+        tgt_item = self.tgt[index] if self.tgt is not None else None
+
+        path = self.aud_paths[index]
+        if not os.path.exists(path):
+            raise FileNotFoundError("Audio file not found: {}".format(path))
+        sound, sample_rate = torchaudio.load_wav(path)
+        output = kaldi.fbank(
+            sound,
+            num_mel_bins=self.num_mel_bins,
+            frame_length=self.frame_length,
+            frame_shift=self.frame_shift,
+        )
+        output_cmvn = data_utils.apply_mv_norm(output)
+
+        return {"id": index, "data": [output_cmvn.detach(), tgt_item]}
+
+    def __len__(self):
+        return len(self.aud_paths)
+
+    def collater(self, samples):
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[int]): sample indices to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
+        return self.s2s_collater.collate(samples)
+
+    def num_tokens(self, index):
+        return self.frame_sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return (
+            self.frame_sizes[index],
+            len(self.tgt[index]) if self.tgt is not None else 0,
+        )
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        return np.arange(len(self))
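`AsrDataset` never loads audio just to size batches; `num_tokens` relies on the frame count precomputed from the duration as `int(1 + (d - frame_length) / frame_shift)`. A quick check with the default 25 ms window and 10 ms shift:
```
frame_length, frame_shift = 25.0, 10.0           # milliseconds (the defaults above)
duration_ms = 1000                                # a one-second utterance
n_frames = int(1 + (duration_ms - frame_length) / frame_shift)
print(n_frames)                                   # 98
```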
fairseq/examples/speech_recognition/data/collaters.py
ADDED
@@ -0,0 +1,131 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+    This module contains collection of classes which implement
+    collate functionalities for various tasks.
+
+    Collaters should know what data to expect for each sample
+    and they should pack / collate them into batches
+"""
+
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import numpy as np
+import torch
+from fairseq.data import data_utils as fairseq_data_utils
+
+
+class Seq2SeqCollater(object):
+    """
+    Implements collate function mainly for seq2seq tasks
+    This expects each sample to contain feature (src_tokens) and
+    targets.
+    This collator is also used for aligned training task.
+    """
+
+    def __init__(
+        self,
+        feature_index=0,
+        label_index=1,
+        pad_index=1,
+        eos_index=2,
+        move_eos_to_beginning=True,
+    ):
+        self.feature_index = feature_index
+        self.label_index = label_index
+        self.pad_index = pad_index
+        self.eos_index = eos_index
+        self.move_eos_to_beginning = move_eos_to_beginning
+
+    def _collate_frames(self, frames):
+        """Convert a list of 2d frames into a padded 3d tensor
+        Args:
+            frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is
+                length of i-th frame and f_dim is static dimension of features
+        Returns:
+            3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
+        """
+        len_max = max(frame.size(0) for frame in frames)
+        f_dim = frames[0].size(1)
+        res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)
+
+        for i, v in enumerate(frames):
+            res[i, : v.size(0)] = v
+
+        return res
+
+    def collate(self, samples):
+        """
+        utility function to collate samples into batch for speech recognition.
+        """
+        if len(samples) == 0:
+            return {}
+
+        # parse samples into torch tensors
+        parsed_samples = []
+        for s in samples:
+            # skip invalid samples
+            if s["data"][self.feature_index] is None:
+                continue
+            source = s["data"][self.feature_index]
+            if isinstance(source, (np.ndarray, np.generic)):
+                source = torch.from_numpy(source)
+            target = s["data"][self.label_index]
+            if isinstance(target, (np.ndarray, np.generic)):
+                target = torch.from_numpy(target).long()
+            elif isinstance(target, list):
+                target = torch.LongTensor(target)
+
+            parsed_sample = {"id": s["id"], "source": source, "target": target}
+            parsed_samples.append(parsed_sample)
+        samples = parsed_samples
+
+        id = torch.LongTensor([s["id"] for s in samples])
+        frames = self._collate_frames([s["source"] for s in samples])
+        # sort samples by descending number of frames
+        frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
+        frames_lengths, sort_order = frames_lengths.sort(descending=True)
+        id = id.index_select(0, sort_order)
+        frames = frames.index_select(0, sort_order)
+
+        target = None
+        target_lengths = None
+        prev_output_tokens = None
+        if samples[0].get("target", None) is not None:
+            ntokens = sum(len(s["target"]) for s in samples)
+            target = fairseq_data_utils.collate_tokens(
+                [s["target"] for s in samples],
+                self.pad_index,
+                self.eos_index,
+                left_pad=False,
+                move_eos_to_beginning=False,
+            )
+            target = target.index_select(0, sort_order)
+            target_lengths = torch.LongTensor(
+                [s["target"].size(0) for s in samples]
+            ).index_select(0, sort_order)
+            prev_output_tokens = fairseq_data_utils.collate_tokens(
+                [s["target"] for s in samples],
+                self.pad_index,
+                self.eos_index,
+                left_pad=False,
+                move_eos_to_beginning=self.move_eos_to_beginning,
+            )
+            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
+        else:
+            ntokens = sum(len(s["source"]) for s in samples)
+
+        batch = {
+            "id": id,
+            "ntokens": ntokens,
+            "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
+            "target": target,
+            "target_lengths": target_lengths,
+            "nsentences": len(samples),
+        }
+        if prev_output_tokens is not None:
+            batch["net_input"]["prev_output_tokens"] = prev_output_tokens
+        return batch
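A minimal usage sketch for `Seq2SeqCollater`, assuming fairseq is installed and this example directory is importable (e.g. via `--user-dir` or `PYTHONPATH`); the feature sizes and token ids below are made up:
```
import torch
from examples.speech_recognition.data.collaters import Seq2SeqCollater

samples = [
    {"id": 0, "data": [torch.randn(12, 80), torch.LongTensor([5, 6, 7, 2])]},
    {"id": 1, "data": [torch.randn(9, 80), torch.LongTensor([8, 9, 2])]},
]
collater = Seq2SeqCollater(0, 1, pad_index=1, eos_index=2, move_eos_to_beginning=True)
batch = collater.collate(samples)
print(batch["net_input"]["src_tokens"].shape)  # torch.Size([2, 12, 80]), zero-padded
print(batch["target_lengths"])                 # tensor([4, 3]), sorted by descending frame count
```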
fairseq/examples/speech_recognition/data/data_utils.py
ADDED
@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def calc_mean_invstddev(feature):
+    if len(feature.size()) != 2:
+        raise ValueError("We expect the input feature to be 2-D tensor")
+    mean = feature.mean(0)
+    var = feature.var(0)
+    # avoid division by ~zero
+    eps = 1e-8
+    if (var < eps).any():
+        return mean, 1.0 / (torch.sqrt(var) + eps)
+    return mean, 1.0 / torch.sqrt(var)
+
+
+def apply_mv_norm(features):
+    # If there is less than 2 spectrograms, the variance cannot be computed (is NaN)
+    # and normalization is not possible, so return the item as it is
+    if features.size(0) < 2:
+        return features
+    mean, invstddev = calc_mean_invstddev(features)
+    res = (features - mean) * invstddev
+    return res
+
+
+def lengths_to_encoder_padding_mask(lengths, batch_first=False):
+    """
+    convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor
+
+    Args:
+        lengths: a (B, )-shaped tensor
+
+    Return:
+        max_length: maximum length of B sequences
+
+        encoder_padding_mask: a (max_length, B) binary mask, where
+        [t, b] = 0 for t < lengths[b] and 1 otherwise
+
+    TODO:
+        kernelize this function if benchmarking shows this function is slow
+    """
+    max_lengths = torch.max(lengths).item()
+    bsz = lengths.size(0)
+    encoder_padding_mask = (
+        torch.arange(max_lengths)  # a (T, ) tensor with [0, ..., T-1]
+        .to(lengths.device)  # move to the right device
+        .view(1, max_lengths)  # reshape to (1, T)-shaped tensor
+        .expand(bsz, -1)  # expand to (B, T)-shaped tensor
+        >= lengths.view(bsz, 1).expand(-1, max_lengths)
+    )
+    if not batch_first:
+        return encoder_padding_mask.t(), max_lengths
+    else:
+        return encoder_padding_mask, max_lengths
+
+
+def encoder_padding_mask_to_lengths(
+    encoder_padding_mask, max_lengths, batch_size, device
+):
+    """
+    convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor
+
+    Conventionally, encoder output contains a encoder_padding_mask, which is
+    a 2-D mask in a shape (T, B), whose (t, b) element indicate whether
+    encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we
+    need to convert this mask tensor to a 1-D tensor in shape (B, ), where
+    [b] denotes the valid length of b-th sequence
+
+    Args:
+        encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None,
+        indicating all are valid
+    Return:
+        seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the
+        number of valid elements of b-th sequence
+
+        max_lengths: maximum length of all sequence, if encoder_padding_mask is
+        not None, max_lengths must equal to encoder_padding_mask.size(0)
+
+        batch_size: batch size; if encoder_padding_mask is
+        not None, max_lengths must equal to encoder_padding_mask.size(1)
+
+        device: which device to put the result on
+    """
+    if encoder_padding_mask is None:
+        return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device)
+
+    assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match"
+    assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match"
+
+    return max_lengths - torch.sum(encoder_padding_mask, dim=0)
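A small demonstration of `lengths_to_encoder_padding_mask` (same import assumption as above); a `1` in the mask marks a padding position:
```
import torch
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask

lengths = torch.LongTensor([4, 2, 3])
mask, max_len = lengths_to_encoder_padding_mask(lengths, batch_first=True)
print(max_len)        # 4
print(mask.int())     # rows: [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]
```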
fairseq/examples/speech_recognition/data/replabels.py
ADDED
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Replabel transforms for use with flashlight's ASG criterion.
+"""
+
+
+def replabel_symbol(i):
+    """
+    Replabel symbols used in flashlight, currently just "1", "2", ...
+    This prevents training with numeral tokens, so this might change in the future
+    """
+    return str(i)
+
+
+def pack_replabels(tokens, dictionary, max_reps):
+    """
+    Pack a token sequence so that repeated symbols are replaced by replabels
+    """
+    if len(tokens) == 0 or max_reps <= 0:
+        return tokens
+
+    replabel_value_to_idx = [0] * (max_reps + 1)
+    for i in range(1, max_reps + 1):
+        replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i))
+
+    result = []
+    prev_token = -1
+    num_reps = 0
+    for token in tokens:
+        if token == prev_token and num_reps < max_reps:
+            num_reps += 1
+        else:
+            if num_reps > 0:
+                result.append(replabel_value_to_idx[num_reps])
+                num_reps = 0
+            result.append(token)
+            prev_token = token
+    if num_reps > 0:
+        result.append(replabel_value_to_idx[num_reps])
+    return result
+
+
+def unpack_replabels(tokens, dictionary, max_reps):
+    """
+    Unpack a token sequence so that replabels are replaced by repeated symbols
+    """
+    if len(tokens) == 0 or max_reps <= 0:
+        return tokens
+
+    replabel_idx_to_value = {}
+    for i in range(1, max_reps + 1):
+        replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i
+
+    result = []
+    prev_token = -1
+    for token in tokens:
+        try:
+            for _ in range(replabel_idx_to_value[token]):
+                result.append(prev_token)
+            prev_token = -1
+        except KeyError:
+            result.append(token)
+            prev_token = token
+    return result
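A round-trip sketch for `pack_replabels` / `unpack_replabels`. The real code passes a fairseq `Dictionary`; only `.index()` is used, so a toy stand-in with hypothetical symbol ids is enough here:
```
from examples.speech_recognition.data.replabels import pack_replabels, unpack_replabels

class ToyDict:
    """Stand-in exposing only the .index() lookup that replabels.py needs."""
    def __init__(self, symbols):
        self.symbols = symbols
    def index(self, sym):
        return self.symbols.index(sym)

d = ToyDict(["<pad>", "1", "2", "h", "e", "l", "o"])     # "1"/"2" are the replabel symbols
hello = [3, 4, 5, 5, 6]                                   # h e l l o
packed = pack_replabels(hello, d, max_reps=2)
print(packed)                                             # [3, 4, 5, 1, 6] -- "ll" becomes "l" + replabel "1"
print(unpack_replabels(packed, d, max_reps=2) == hello)   # True
```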
fairseq/examples/speech_recognition/datasets/asr_prep_json.py
ADDED
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import argparse
+import concurrent.futures
+import json
+import multiprocessing
+import os
+from collections import namedtuple
+from itertools import chain
+
+import sentencepiece as spm
+from fairseq.data import Dictionary
+
+
+MILLISECONDS_TO_SECONDS = 0.001
+
+
+def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
+    import torchaudio
+
+    input = {}
+    output = {}
+    si, ei = torchaudio.info(aud_path)
+    input["length_ms"] = int(
+        si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
+    )
+    input["path"] = aud_path
+
+    token = " ".join(sp.EncodeAsPieces(lable))
+    ids = tgt_dict.encode_line(token, append_eos=False)
+    output["text"] = lable
+    output["token"] = token
+    output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
+    return {utt_id: {"input": input, "output": output}}
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--audio-dirs",
+        nargs="+",
+        default=["-"],
+        required=True,
+        help="input directories with audio files",
+    )
+    parser.add_argument(
+        "--labels",
+        required=True,
+        help="aggregated input labels with format <ID LABEL> per line",
+        type=argparse.FileType("r", encoding="UTF-8"),
+    )
+    parser.add_argument(
+        "--spm-model",
+        required=True,
+        help="sentencepiece model to use for encoding",
+        type=argparse.FileType("r", encoding="UTF-8"),
+    )
+    parser.add_argument(
+        "--dictionary",
+        required=True,
+        help="file to load fairseq dictionary from",
+        type=argparse.FileType("r", encoding="UTF-8"),
+    )
+    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
+    parser.add_argument(
+        "--output",
+        required=True,
+        type=argparse.FileType("w"),
+        help="path to save json output",
+    )
+    args = parser.parse_args()
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.spm_model.name)
+
+    tgt_dict = Dictionary.load(args.dictionary)
+
+    labels = {}
+    for line in args.labels:
+        (utt_id, label) = line.split(" ", 1)
+        labels[utt_id] = label
+    if len(labels) == 0:
+        raise Exception("No labels found in ", args.labels_path)
+
+    Sample = namedtuple("Sample", "aud_path utt_id")
+    samples = []
+    for path, _, files in chain.from_iterable(
+        os.walk(path) for path in args.audio_dirs
+    ):
+        for f in files:
+            if f.endswith(args.audio_format):
+                if len(os.path.splitext(f)) != 2:
+                    raise Exception("Expect <utt_id.extension> file name. Got: ", f)
+                utt_id = os.path.splitext(f)[0]
+                if utt_id not in labels:
+                    continue
+                samples.append(Sample(os.path.join(path, f), utt_id))
+
+    utts = {}
+    num_cpu = multiprocessing.cpu_count()
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
+        future_to_sample = {
+            executor.submit(
+                process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
+            ): s
+            for s in samples
+        }
+        for future in concurrent.futures.as_completed(future_to_sample):
+            try:
+                data = future.result()
+            except Exception as exc:
+                print("generated an exception: ", exc)
+            else:
+                utts.update(data)
+    json.dump({"utts": utts}, args.output, indent=4)
+
+
+if __name__ == "__main__":
+    main()
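For reference, the JSON written by this script has one entry per utterance under a top-level `utts` key, mirroring what `process_sample` returns. A sketch with placeholder values (the id, path, text, and token ids below are invented for illustration, not taken from any dataset):
```
import json

utts = {
    "utt_0001": {
        "input": {"length_ms": 5000, "path": "/data/LibriSpeech/train_960/utt_0001.flac"},
        "output": {
            "text": "HELLO WORLD\n",
            "token": "▁HELLO ▁WORLD",
            "tokenid": "42, 43",
        },
    }
}
print(json.dumps({"utts": utts}, indent=4))
```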
fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh
ADDED
@@ -0,0 +1,88 @@
+#!/usr/bin/env bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Prepare librispeech dataset
+
+base_url=www.openslr.org/resources/12
+train_dir=train_960
+
+if [ "$#" -ne 2 ]; then
+  echo "Usage: $0 <download_dir> <out_dir>"
+  echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
+  exit 1
+fi
+
+download_dir=${1%/}
+out_dir=${2%/}
+
+fairseq_root=~/fairseq-py/
+mkdir -p ${out_dir}
+cd ${out_dir} || exit
+
+nbpe=5000
+bpemode=unigram
+
+if [ ! -d "$fairseq_root" ]; then
+  echo "$0: Please set correct fairseq_root"
+  exit 1
+fi
+
+echo "Data Download"
+for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+  url=$base_url/$part.tar.gz
+  if ! wget -P $download_dir $url; then
+    echo "$0: wget failed for $url"
+    exit 1
+  fi
+  if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
+    echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
+    exit 1
+  fi
+done
+
+echo "Merge all train packs into one"
+mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
+for part in train-clean-100 train-clean-360 train-other-500; do
+  mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
+done
+echo "Merge train text"
+find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text
+
+# Use combined dev-clean and dev-other as validation set
+find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
+find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
+find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text
+
+
+dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
+encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
+fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
+bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
+echo "dictionary: ${dict}"
+echo "Dictionary preparation"
+mkdir -p data/lang_char/
+echo "<unk> 3" > ${dict}
+echo "</s> 2" >> ${dict}
+echo "<pad> 1" >> ${dict}
+cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
+spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
+spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
+cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
+cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
+wc -l ${dict}
+
+echo "Prepare train and test jsons"
+for part in train_960 test-other test-clean; do
+  python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
+done
+# fairseq expects to find train.json and valid.json during training
+mv train_960.json train.json
+
+echo "Prepare valid json"
+python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json
+
+cp ${fairseq_dict} ./dict.txt
+cp ${bpemodel}.model ./spm.model
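The `*_fairseq_dict.txt` copied to `dict.txt` at the end holds one `<piece> <count>` pair per line, which is the format `fairseq.data.Dictionary` expects and what `asr_prep_json.py` loads. A quick sanity-check sketch, assuming the script above has been run in the current directory:
```
from fairseq.data import Dictionary

tgt_dict = Dictionary.load("dict.txt")   # lines look like "▁THE 123456" (counts vary per corpus)
print(len(tgt_dict), tgt_dict.pad(), tgt_dict.eos())
```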
fairseq/examples/speech_recognition/infer.py
ADDED
@@ -0,0 +1,436 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run inference for pre-processed data with a trained model.
+"""
+
+import ast
+import logging
+import math
+import os
+import sys
+
+import editdistance
+import numpy as np
+import torch
+from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
+from fairseq.data.data_utils import post_process
+from fairseq.logging.meters import StopwatchMeter, TimeMeter
+
+
+logging.basicConfig()
+logging.root.setLevel(logging.INFO)
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def add_asr_eval_argument(parser):
+    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
+    parser.add_argument(
+        "--wfstlm", default=None, help="wfstlm on dictonary output units"
+    )
+    parser.add_argument(
+        "--rnnt_decoding_type",
+        default="greedy",
+        help="wfstlm on dictonary\
+output units",
+    )
+    try:
+        parser.add_argument(
+            "--lm-weight",
+            "--lm_weight",
+            type=float,
+            default=0.2,
+            help="weight for lm while interpolating with neural score",
+        )
+    except:
+        pass
+    parser.add_argument(
+        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
+    )
+    parser.add_argument(
+        "--w2l-decoder",
+        choices=["viterbi", "kenlm", "fairseqlm"],
+        help="use a w2l decoder",
+    )
+    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
+    parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
+    parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
+    parser.add_argument("--beam-threshold", type=float, default=25.0)
+    parser.add_argument("--beam-size-token", type=float, default=100)
+    parser.add_argument("--word-score", type=float, default=1.0)
+    parser.add_argument("--unk-weight", type=float, default=-math.inf)
+    parser.add_argument("--sil-weight", type=float, default=0.0)
+    parser.add_argument(
+        "--dump-emissions",
+        type=str,
+        default=None,
+        help="if present, dumps emissions into this file and exits",
+    )
+    parser.add_argument(
+        "--dump-features",
+        type=str,
+        default=None,
+        help="if present, dumps features into this file and exits",
+    )
+    parser.add_argument(
+        "--load-emissions",
+        type=str,
+        default=None,
+        help="if present, loads emissions from this file",
+    )
+    return parser
+
+
+def check_args(args):
+    # assert args.path is not None, "--path required for generation!"
+    # assert args.results_path is not None, "--results_path required for generation!"
+    assert (
+        not args.sampling or args.nbest == args.beam
+    ), "--sampling requires --nbest to be equal to --beam"
+    assert (
+        args.replace_unk is None or args.raw_text
+    ), "--replace-unk requires a raw text dataset (--raw-text)"
+
+
+def get_dataset_itr(args, task, models):
+    return task.get_batch_iterator(
+        dataset=task.dataset(args.gen_subset),
+        max_tokens=args.max_tokens,
+        max_sentences=args.batch_size,
+        max_positions=(sys.maxsize, sys.maxsize),
+        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
+        required_batch_size_multiple=args.required_batch_size_multiple,
+        num_shards=args.num_shards,
+        shard_id=args.shard_id,
+        num_workers=args.num_workers,
+        data_buffer_size=args.data_buffer_size,
+    ).next_epoch_itr(shuffle=False)
+
+
+def process_predictions(
+    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
+):
+    for hypo in hypos[: min(len(hypos), args.nbest)]:
+        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
+
+        if "words" in hypo:
+            hyp_words = " ".join(hypo["words"])
+        else:
+            hyp_words = post_process(hyp_pieces, args.post_process)
+
+        if res_files is not None:
+            print(
+                "{} ({}-{})".format(hyp_pieces, speaker, id),
+                file=res_files["hypo.units"],
+            )
+            print(
+                "{} ({}-{})".format(hyp_words, speaker, id),
+                file=res_files["hypo.words"],
+            )
+
+        tgt_pieces = tgt_dict.string(target_tokens)
+        tgt_words = post_process(tgt_pieces, args.post_process)
+
+        if res_files is not None:
+            print(
+                "{} ({}-{})".format(tgt_pieces, speaker, id),
+                file=res_files["ref.units"],
+            )
+            print(
+                "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
+            )
+
+        if not args.quiet:
+            logger.info("HYPO:" + hyp_words)
150 |
+
logger.info("TARGET:" + tgt_words)
|
151 |
+
logger.info("___________________")
|
152 |
+
|
153 |
+
hyp_words = hyp_words.split()
|
154 |
+
tgt_words = tgt_words.split()
|
155 |
+
return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
|
156 |
+
|
157 |
+
|
158 |
+
def prepare_result_files(args):
|
159 |
+
def get_res_file(file_prefix):
|
160 |
+
if args.num_shards > 1:
|
161 |
+
file_prefix = f"{args.shard_id}_{file_prefix}"
|
162 |
+
path = os.path.join(
|
163 |
+
args.results_path,
|
164 |
+
"{}-{}-{}.txt".format(
|
165 |
+
file_prefix, os.path.basename(args.path), args.gen_subset
|
166 |
+
),
|
167 |
+
)
|
168 |
+
return open(path, "w", buffering=1)
|
169 |
+
|
170 |
+
if not args.results_path:
|
171 |
+
return None
|
172 |
+
|
173 |
+
return {
|
174 |
+
"hypo.words": get_res_file("hypo.word"),
|
175 |
+
"hypo.units": get_res_file("hypo.units"),
|
176 |
+
"ref.words": get_res_file("ref.word"),
|
177 |
+
"ref.units": get_res_file("ref.units"),
|
178 |
+
}
|
179 |
+
|
180 |
+
|
181 |
+
def optimize_models(args, use_cuda, models):
|
182 |
+
"""Optimize ensemble for generation"""
|
183 |
+
for model in models:
|
184 |
+
model.make_generation_fast_(
|
185 |
+
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
|
186 |
+
need_attn=args.print_alignment,
|
187 |
+
)
|
188 |
+
if args.fp16:
|
189 |
+
model.half()
|
190 |
+
if use_cuda:
|
191 |
+
model.cuda()
|
192 |
+
|
193 |
+
|
194 |
+
def apply_half(t):
|
195 |
+
if t.dtype is torch.float32:
|
196 |
+
return t.to(dtype=torch.half)
|
197 |
+
return t
|
198 |
+
|
199 |
+
|
200 |
+
class ExistingEmissionsDecoder(object):
|
201 |
+
def __init__(self, decoder, emissions):
|
202 |
+
self.decoder = decoder
|
203 |
+
self.emissions = emissions
|
204 |
+
|
205 |
+
def generate(self, models, sample, **unused):
|
206 |
+
ids = sample["id"].cpu().numpy()
|
207 |
+
try:
|
208 |
+
emissions = np.stack(self.emissions[ids])
|
209 |
+
except:
|
210 |
+
print([x.shape for x in self.emissions[ids]])
|
211 |
+
raise Exception("invalid sizes")
|
212 |
+
emissions = torch.from_numpy(emissions)
|
213 |
+
return self.decoder.decode(emissions)
|
214 |
+
|
215 |
+
|
216 |
+
def main(args, task=None, model_state=None):
|
217 |
+
check_args(args)
|
218 |
+
|
219 |
+
use_fp16 = args.fp16
|
220 |
+
if args.max_tokens is None and args.batch_size is None:
|
221 |
+
args.max_tokens = 4000000
|
222 |
+
logger.info(args)
|
223 |
+
|
224 |
+
use_cuda = torch.cuda.is_available() and not args.cpu
|
225 |
+
|
226 |
+
logger.info("| decoding with criterion {}".format(args.criterion))
|
227 |
+
|
228 |
+
task = tasks.setup_task(args)
|
229 |
+
|
230 |
+
# Load ensemble
|
231 |
+
if args.load_emissions:
|
232 |
+
models, criterions = [], []
|
233 |
+
task.load_dataset(args.gen_subset)
|
234 |
+
else:
|
235 |
+
logger.info("| loading model(s) from {}".format(args.path))
|
236 |
+
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
237 |
+
utils.split_paths(args.path, separator="\\"),
|
238 |
+
arg_overrides=ast.literal_eval(args.model_overrides),
|
239 |
+
task=task,
|
240 |
+
suffix=args.checkpoint_suffix,
|
241 |
+
strict=(args.checkpoint_shard_count == 1),
|
242 |
+
num_shards=args.checkpoint_shard_count,
|
243 |
+
state=model_state,
|
244 |
+
)
|
245 |
+
optimize_models(args, use_cuda, models)
|
246 |
+
task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)
|
247 |
+
|
248 |
+
|
249 |
+
# Set dictionary
|
250 |
+
tgt_dict = task.target_dictionary
|
251 |
+
|
252 |
+
logger.info(
|
253 |
+
"| {} {} {} examples".format(
|
254 |
+
args.data, args.gen_subset, len(task.dataset(args.gen_subset))
|
255 |
+
)
|
256 |
+
)
|
257 |
+
|
258 |
+
# hack to pass transitions to W2lDecoder
|
259 |
+
if args.criterion == "asg_loss":
|
260 |
+
raise NotImplementedError("asg_loss is currently not supported")
|
261 |
+
# trans = criterions[0].asg.trans.data
|
262 |
+
# args.asg_transitions = torch.flatten(trans).tolist()
|
263 |
+
|
264 |
+
# Load dataset (possibly sharded)
|
265 |
+
itr = get_dataset_itr(args, task, models)
|
266 |
+
|
267 |
+
# Initialize generator
|
268 |
+
gen_timer = StopwatchMeter()
|
269 |
+
|
270 |
+
def build_generator(args):
|
271 |
+
w2l_decoder = getattr(args, "w2l_decoder", None)
|
272 |
+
if w2l_decoder == "viterbi":
|
273 |
+
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
|
274 |
+
|
275 |
+
return W2lViterbiDecoder(args, task.target_dictionary)
|
276 |
+
elif w2l_decoder == "kenlm":
|
277 |
+
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
|
278 |
+
|
279 |
+
return W2lKenLMDecoder(args, task.target_dictionary)
|
280 |
+
elif w2l_decoder == "fairseqlm":
|
281 |
+
from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
|
282 |
+
|
283 |
+
return W2lFairseqLMDecoder(args, task.target_dictionary)
|
284 |
+
else:
|
285 |
+
print(
|
286 |
+
"only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
|
287 |
+
)
|
288 |
+
|
289 |
+
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
|
290 |
+
generator = build_generator(args)
|
291 |
+
|
292 |
+
if args.load_emissions:
|
293 |
+
generator = ExistingEmissionsDecoder(
|
294 |
+
generator, np.load(args.load_emissions, allow_pickle=True)
|
295 |
+
)
|
296 |
+
logger.info("loaded emissions from " + args.load_emissions)
|
297 |
+
|
298 |
+
num_sentences = 0
|
299 |
+
|
300 |
+
if args.results_path is not None and not os.path.exists(args.results_path):
|
301 |
+
os.makedirs(args.results_path)
|
302 |
+
|
303 |
+
max_source_pos = (
|
304 |
+
utils.resolve_max_positions(
|
305 |
+
task.max_positions(), *[model.max_positions() for model in models]
|
306 |
+
),
|
307 |
+
)
|
308 |
+
|
309 |
+
if max_source_pos is not None:
|
310 |
+
max_source_pos = max_source_pos[0]
|
311 |
+
if max_source_pos is not None:
|
312 |
+
max_source_pos = max_source_pos[0] - 1
|
313 |
+
|
314 |
+
if args.dump_emissions:
|
315 |
+
emissions = {}
|
316 |
+
if args.dump_features:
|
317 |
+
features = {}
|
318 |
+
models[0].bert.proj = None
|
319 |
+
else:
|
320 |
+
res_files = prepare_result_files(args)
|
321 |
+
errs_t = 0
|
322 |
+
lengths_t = 0
|
323 |
+
with progress_bar.build_progress_bar(args, itr) as t:
|
324 |
+
wps_meter = TimeMeter()
|
325 |
+
for sample in t:
|
326 |
+
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
327 |
+
if use_fp16:
|
328 |
+
sample = utils.apply_to_sample(apply_half, sample)
|
329 |
+
if "net_input" not in sample:
|
330 |
+
continue
|
331 |
+
|
332 |
+
prefix_tokens = None
|
333 |
+
if args.prefix_size > 0:
|
334 |
+
prefix_tokens = sample["target"][:, : args.prefix_size]
|
335 |
+
|
336 |
+
gen_timer.start()
|
337 |
+
if args.dump_emissions:
|
338 |
+
with torch.no_grad():
|
339 |
+
encoder_out = models[0](**sample["net_input"])
|
340 |
+
emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
|
341 |
+
emm = emm.transpose(0, 1).cpu().numpy()
|
342 |
+
for i, id in enumerate(sample["id"]):
|
343 |
+
emissions[id.item()] = emm[i]
|
344 |
+
continue
|
345 |
+
elif args.dump_features:
|
346 |
+
with torch.no_grad():
|
347 |
+
encoder_out = models[0](**sample["net_input"])
|
348 |
+
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
|
349 |
+
for i, id in enumerate(sample["id"]):
|
350 |
+
padding = (
|
351 |
+
encoder_out["encoder_padding_mask"][i].cpu().numpy()
|
352 |
+
if encoder_out["encoder_padding_mask"] is not None
|
353 |
+
else None
|
354 |
+
)
|
355 |
+
features[id.item()] = (feat[i], padding)
|
356 |
+
continue
|
357 |
+
hypos = task.inference_step(generator, models, sample, prefix_tokens)
|
358 |
+
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
|
359 |
+
gen_timer.stop(num_generated_tokens)
|
360 |
+
|
361 |
+
for i, sample_id in enumerate(sample["id"].tolist()):
|
362 |
+
speaker = None
|
363 |
+
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
|
364 |
+
id = sample_id
|
365 |
+
toks = (
|
366 |
+
sample["target"][i, :]
|
367 |
+
if "target_label" not in sample
|
368 |
+
else sample["target_label"][i, :]
|
369 |
+
)
|
370 |
+
target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
|
371 |
+
# Process top predictions
|
372 |
+
errs, length = process_predictions(
|
373 |
+
args,
|
374 |
+
hypos[i],
|
375 |
+
None,
|
376 |
+
tgt_dict,
|
377 |
+
target_tokens,
|
378 |
+
res_files,
|
379 |
+
speaker,
|
380 |
+
id,
|
381 |
+
)
|
382 |
+
errs_t += errs
|
383 |
+
lengths_t += length
|
384 |
+
|
385 |
+
wps_meter.update(num_generated_tokens)
|
386 |
+
t.log({"wps": round(wps_meter.avg)})
|
387 |
+
num_sentences += (
|
388 |
+
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
|
389 |
+
)
|
390 |
+
|
391 |
+
wer = None
|
392 |
+
if args.dump_emissions:
|
393 |
+
emm_arr = []
|
394 |
+
for i in range(len(emissions)):
|
395 |
+
emm_arr.append(emissions[i])
|
396 |
+
np.save(args.dump_emissions, emm_arr)
|
397 |
+
logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
|
398 |
+
elif args.dump_features:
|
399 |
+
feat_arr = []
|
400 |
+
for i in range(len(features)):
|
401 |
+
feat_arr.append(features[i])
|
402 |
+
np.save(args.dump_features, feat_arr)
|
403 |
+
logger.info(f"saved {len(features)} emissions to {args.dump_features}")
|
404 |
+
else:
|
405 |
+
if lengths_t > 0:
|
406 |
+
wer = errs_t * 100.0 / lengths_t
|
407 |
+
logger.info(f"WER: {wer}")
|
408 |
+
|
409 |
+
logger.info(
|
410 |
+
"| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
|
411 |
+
"sentences/s, {:.2f} tokens/s)".format(
|
412 |
+
num_sentences,
|
413 |
+
gen_timer.n,
|
414 |
+
gen_timer.sum,
|
415 |
+
num_sentences / gen_timer.sum,
|
416 |
+
1.0 / gen_timer.avg,
|
417 |
+
)
|
418 |
+
)
|
419 |
+
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
|
420 |
+
return task, wer
|
421 |
+
|
422 |
+
|
423 |
+
def make_parser():
|
424 |
+
parser = options.get_generation_parser()
|
425 |
+
parser = add_asr_eval_argument(parser)
|
426 |
+
return parser
|
427 |
+
|
428 |
+
|
429 |
+
def cli_main():
|
430 |
+
parser = make_parser()
|
431 |
+
args = options.parse_args_and_arch(parser)
|
432 |
+
main(args)
|
433 |
+
|
434 |
+
|
435 |
+
if __name__ == "__main__":
|
436 |
+
cli_main()
|
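The script above is normally driven through fairseq's generation parser. A minimal, hypothetical invocation sketch follows (not part of the uploaded files): the manifest directory, checkpoint path, and subset name are placeholders, and only flags defined in add_asr_eval_argument plus standard fairseq generation options are used.

# Hypothetical usage sketch for infer.py; every path and the subset name are placeholders.
from fairseq import options
from examples.speech_recognition.infer import make_parser, main

parser = make_parser()  # generation parser extended with the ASR eval arguments above
args = options.parse_args_and_arch(
    parser,
    input_args=[
        "/path/to/manifest_dir",        # data directory for the chosen task
        "--task", "audio_pretraining",
        "--path", "/path/to/checkpoint.pt",
        "--gen-subset", "dev_other",
        "--w2l-decoder", "viterbi",     # or "kenlm"/"fairseqlm" together with --lexicon/--kenlm-model
        "--post-process", "letter",
        "--max-tokens", "4000000",
        "--results-path", "/tmp/decode_out",
    ],
)
task, wer = main(args)                  # returns the task and the corpus-level WER
print("WER:", wer)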
fairseq/examples/speech_recognition/kaldi/__init__.py
ADDED
File without changes
fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc
ADDED
@@ -0,0 +1,94 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <iostream>
#include "fstext/fstext-lib.h" // @manual
#include "util/common-utils.h" // @manual

/*
 * This program is to modify a FST without self-loop by:
 * for each incoming arc with non-eps input symbol, add a self-loop arc
 * with that non-eps symbol as input and eps as output.
 *
 * This is to make sure the resultant FST can do deduplication for repeated
 * symbols, which is very common in acoustic model
 *
 */
namespace {
int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) {
  typedef fst::MutableArcIterator<fst::StdVectorFst> IterType;

  int32 num_states_before = fst->NumStates();
  fst::MakePrecedingInputSymbolsSame(false, fst);
  int32 num_states_after = fst->NumStates();
  KALDI_LOG << "There are " << num_states_before
            << " states in the original FST; "
            << " after MakePrecedingInputSymbolsSame, there are "
            << num_states_after << " states " << std::endl;

  auto weight_one = fst::StdArc::Weight::One();

  int32 num_arc_added = 0;

  fst::StdArc self_loop_arc;
  self_loop_arc.weight = weight_one;

  int32 num_states = fst->NumStates();
  std::vector<std::set<int32>> incoming_non_eps_label_per_state(num_states);

  for (int32 state = 0; state < num_states; state++) {
    for (IterType aiter(fst, state); !aiter.Done(); aiter.Next()) {
      fst::StdArc arc(aiter.Value());
      if (arc.ilabel != 0) {
        incoming_non_eps_label_per_state[arc.nextstate].insert(arc.ilabel);
      }
    }
  }

  for (int32 state = 0; state < num_states; state++) {
    if (!incoming_non_eps_label_per_state[state].empty()) {
      auto& ilabel_set = incoming_non_eps_label_per_state[state];
      for (auto it = ilabel_set.begin(); it != ilabel_set.end(); it++) {
        self_loop_arc.ilabel = *it;
        self_loop_arc.olabel = 0;
        self_loop_arc.nextstate = state;
        fst->AddArc(state, self_loop_arc);
        num_arc_added++;
      }
    }
  }
  return num_arc_added;
}

void print_usage() {
  std::cout << "add-self-loop-simple usage:\n"
               "\tadd-self-loop-simple <in-fst> <out-fst> \n";
}
} // namespace

int main(int argc, char** argv) {
  if (argc != 3) {
    print_usage();
    exit(1);
  }

  auto input = argv[1];
  auto output = argv[2];

  auto fst = fst::ReadFstKaldi(input);
  auto num_states = fst->NumStates();
  KALDI_LOG << "Loading FST from " << input << " with " << num_states
            << " states." << std::endl;

  int32 num_arc_added = AddSelfLoopsSimple(fst);
  KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl;

  fst::WriteFstKaldi(*fst, std::string(output));
  KALDI_LOG << "Writing FST to " << output << std::endl;

  delete fst;
}
fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml
ADDED
@@ -0,0 +1,8 @@
# @package _group_

data_dir: ???
fst_dir: ???
in_labels: ???
kaldi_root: ???
lm_arpa: ???
blank_symbol: <s>
fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py
ADDED
@@ -0,0 +1,244 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from concurrent.futures import ThreadPoolExecutor
import logging
from omegaconf import MISSING
import os
import torch
from typing import Optional
import warnings


from dataclasses import dataclass
from fairseq.dataclass import FairseqDataclass
from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi


logger = logging.getLogger(__name__)


@dataclass
class KaldiDecoderConfig(FairseqDataclass):
    hlg_graph_path: Optional[str] = None
    output_dict: str = MISSING

    kaldi_initializer_config: Optional[KaldiInitializerConfig] = None

    acoustic_scale: float = 0.5
    max_active: int = 10000
    beam_delta: float = 0.5
    hash_ratio: float = 2.0

    is_lattice: bool = False
    lattice_beam: float = 10.0
    prune_interval: int = 25
    determinize_lattice: bool = True
    prune_scale: float = 0.1
    max_mem: int = 0
    phone_determinize: bool = True
    word_determinize: bool = True
    minimize: bool = True

    num_threads: int = 1


class KaldiDecoder(object):
    def __init__(
        self,
        cfg: KaldiDecoderConfig,
        beam: int,
        nbest: int = 1,
    ):
        try:
            from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
            from kaldi.base import set_verbose_level
            from kaldi.decoder import (
                FasterDecoder,
                FasterDecoderOptions,
                LatticeFasterDecoder,
                LatticeFasterDecoderOptions,
            )
            from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
            from kaldi.fstext import read_fst_kaldi, SymbolTable
        except:
            warnings.warn(
                "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi"
            )

        # set_verbose_level(2)

        self.acoustic_scale = cfg.acoustic_scale
        self.nbest = nbest

        if cfg.hlg_graph_path is None:
            assert (
                cfg.kaldi_initializer_config is not None
            ), "Must provide hlg graph path or kaldi initializer config"
            cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)

        assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path

        if cfg.is_lattice:
            self.dec_cls = LatticeFasterDecoder
            opt_cls = LatticeFasterDecoderOptions
            self.rec_cls = LatticeFasterRecognizer
        else:
            assert self.nbest == 1, "nbest > 1 requires lattice decoder"
            self.dec_cls = FasterDecoder
            opt_cls = FasterDecoderOptions
            self.rec_cls = FasterRecognizer

        self.decoder_options = opt_cls()
        self.decoder_options.beam = beam
        self.decoder_options.max_active = cfg.max_active
        self.decoder_options.beam_delta = cfg.beam_delta
        self.decoder_options.hash_ratio = cfg.hash_ratio

        if cfg.is_lattice:
            self.decoder_options.lattice_beam = cfg.lattice_beam
            self.decoder_options.prune_interval = cfg.prune_interval
            self.decoder_options.determinize_lattice = cfg.determinize_lattice
            self.decoder_options.prune_scale = cfg.prune_scale
            det_opts = DeterminizeLatticePhonePrunedOptions()
            det_opts.max_mem = cfg.max_mem
            det_opts.phone_determinize = cfg.phone_determinize
            det_opts.word_determinize = cfg.word_determinize
            det_opts.minimize = cfg.minimize
            self.decoder_options.det_opts = det_opts

        self.output_symbols = {}
        with open(cfg.output_dict, "r") as f:
            for line in f:
                items = line.rstrip().split()
                assert len(items) == 2
                self.output_symbols[int(items[1])] = items[0]

        logger.info(f"Loading FST from {cfg.hlg_graph_path}")
        self.fst = read_fst_kaldi(cfg.hlg_graph_path)
        self.symbol_table = SymbolTable.read_text(cfg.output_dict)

        self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions, padding = self.get_emissions(models, encoder_input)
        return self.decode(emissions, padding)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        model = models[0]

        all_encoder_out = [m(**encoder_input) for m in models]

        if len(all_encoder_out) > 1:

            if "encoder_out" in all_encoder_out[0]:
                encoder_out = {
                    "encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
                }
                padding = encoder_out["encoder_padding_mask"]
            else:
                encoder_out = {
                    "logits": sum(e["logits"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "padding_mask": all_encoder_out[0]["padding_mask"],
                }
                padding = encoder_out["padding_mask"]
        else:
            encoder_out = all_encoder_out[0]
            padding = (
                encoder_out["padding_mask"]
                if "padding_mask" in encoder_out
                else encoder_out["encoder_padding_mask"]
            )

        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out, normalize=True)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)

        return (
            emissions.cpu().float().transpose(0, 1),
            padding.cpu() if padding is not None and padding.any() else None,
        )

    def decode_one(self, logits, padding):
        from kaldi.matrix import Matrix

        decoder = self.dec_cls(self.fst, self.decoder_options)
        asr = self.rec_cls(
            decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
        )

        if padding is not None:
            logits = logits[~padding]

        mat = Matrix(logits.numpy())

        out = asr.decode(mat)

        if self.nbest > 1:
            from kaldi.fstext import shortestpath
            from kaldi.fstext.utils import (
                convert_compact_lattice_to_lattice,
                convert_lattice_to_std,
                convert_nbest_to_list,
                get_linear_symbol_sequence,
            )

            lat = out["lattice"]

            sp = shortestpath(lat, nshortest=self.nbest)

            sp = convert_compact_lattice_to_lattice(sp)
            sp = convert_lattice_to_std(sp)
            seq = convert_nbest_to_list(sp)

            results = []
            for s in seq:
                _, o, w = get_linear_symbol_sequence(s)
                words = list(self.output_symbols[z] for z in o)
                results.append(
                    {
                        "tokens": words,
                        "words": words,
                        "score": w.value,
                        "emissions": logits,
                    }
                )
            return results
        else:
            words = out["text"].split()
            return [
                {
                    "tokens": words,
                    "words": words,
                    "score": out["likelihood"],
                    "emissions": logits,
                }
            ]

    def decode(self, emissions, padding):
        if padding is None:
            padding = [None] * len(emissions)

        ret = list(
            map(
                lambda e, p: self.executor.submit(self.decode_one, e, p),
                emissions,
                padding,
            )
        )
        return ret
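A rough sketch of how KaldiDecoder might be wired up (not part of the uploaded files): pykaldi must be installed, the graph and dictionary paths are placeholders, and models/sample are assumed to be a model ensemble and batch prepared the same way as in infer.py. Note that decode() submits one job per utterance to the thread pool, so generate() returns concurrent.futures.Future objects.

# Hypothetical usage sketch; all paths are placeholders and `models`/`sample`
# are assumed to come from the surrounding inference loop.
from examples.speech_recognition.kaldi.kaldi_decoder import (
    KaldiDecoder,
    KaldiDecoderConfig,
)

cfg = KaldiDecoderConfig(
    hlg_graph_path="/path/to/HLG.fst",       # or set kaldi_initializer_config instead
    output_dict="/path/to/kaldi_dict.words.txt",
    acoustic_scale=0.5,
    num_threads=4,
)
decoder = KaldiDecoder(cfg, beam=15)

futures = decoder.generate(models, sample)   # one Future per utterance in the batch
hypos = [f.result() for f in futures]        # each result is a list of hypothesis dicts
print(" ".join(hypos[0][0]["words"]))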
fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py
ADDED
@@ -0,0 +1,698 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import hydra
from hydra.core.config_store import ConfigStore
import logging
from omegaconf import MISSING, OmegaConf
import os
import os.path as osp
from pathlib import Path
import subprocess
from typing import Optional

from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import FairseqDataclass

script_dir = Path(__file__).resolve().parent
config_path = script_dir / "config"


logger = logging.getLogger(__name__)


@dataclass
class KaldiInitializerConfig(FairseqDataclass):
    data_dir: str = MISSING
    fst_dir: Optional[str] = None
    in_labels: str = MISSING
    out_labels: Optional[str] = None
    wav2letter_lexicon: Optional[str] = None
    lm_arpa: str = MISSING
    kaldi_root: str = MISSING
    blank_symbol: str = "<s>"
    silence_symbol: Optional[str] = None


def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path:
    in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt"
    if not in_units_file.exists():

        logger.info(f"Creating {in_units_file}")

        with open(in_units_file, "w") as f:
            print("<eps> 0", file=f)
            i = 1
            for symb in vocab.symbols[vocab.nspecial :]:
                if not symb.startswith("madeupword"):
                    print(f"{symb} {i}", file=f)
                    i += 1
    return in_units_file


def create_lexicon(
    cfg: KaldiInitializerConfig,
    fst_dir: Path,
    unique_label: str,
    in_units_file: Path,
    out_words_file: Path,
) -> (Path, Path):

    disambig_in_units_file = fst_dir / f"kaldi_dict.{cfg.in_labels}_disambig.txt"
    lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}.txt"
    disambig_lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}_disambig.txt"
    if (
        not lexicon_file.exists()
        or not disambig_lexicon_file.exists()
        or not disambig_in_units_file.exists()
    ):
        logger.info(f"Creating {lexicon_file} (in units file: {in_units_file})")

        assert cfg.wav2letter_lexicon is not None or cfg.in_labels == cfg.out_labels

        if cfg.wav2letter_lexicon is not None:
            lm_words = set()
            with open(out_words_file, "r") as lm_dict_f:
                for line in lm_dict_f:
                    lm_words.add(line.split()[0])

            num_skipped = 0
            total = 0
            with open(cfg.wav2letter_lexicon, "r") as w2l_lex_f, open(
                lexicon_file, "w"
            ) as out_f:
                for line in w2l_lex_f:
                    items = line.rstrip().split("\t")
                    assert len(items) == 2, items
                    if items[0] in lm_words:
                        print(items[0], items[1], file=out_f)
                    else:
                        num_skipped += 1
                        logger.debug(
                            f"Skipping word {items[0]} as it was not found in LM"
                        )
                    total += 1
            if num_skipped > 0:
                logger.warning(
                    f"Skipped {num_skipped} out of {total} words as they were not found in LM"
                )
        else:
            with open(in_units_file, "r") as in_f, open(lexicon_file, "w") as out_f:
                for line in in_f:
                    symb = line.split()[0]
                    if symb != "<eps>" and symb != "<ctc_blank>" and symb != "<SIL>":
                        print(symb, symb, file=out_f)

        lex_disambig_path = (
            Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_lex_disambig.pl"
        )
        res = subprocess.run(
            [lex_disambig_path, lexicon_file, disambig_lexicon_file],
            check=True,
            capture_output=True,
        )
        ndisambig = int(res.stdout)
        disamib_path = Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_disambig.pl"
        res = subprocess.run(
            [disamib_path, "--include-zero", in_units_file, str(ndisambig)],
            check=True,
            capture_output=True,
        )
        with open(disambig_in_units_file, "wb") as f:
            f.write(res.stdout)

    return disambig_lexicon_file, disambig_in_units_file


def create_G(
    kaldi_root: Path, fst_dir: Path, lm_arpa: Path, arpa_base: str
) -> (Path, Path):

    out_words_file = fst_dir / f"kaldi_dict.{arpa_base}.txt"
    grammar_graph = fst_dir / f"G_{arpa_base}.fst"
    if not grammar_graph.exists() or not out_words_file.exists():
        logger.info(f"Creating {grammar_graph}")
        arpa2fst = kaldi_root / "src/lmbin/arpa2fst"
        subprocess.run(
            [
                arpa2fst,
                "--disambig-symbol=#0",
                f"--write-symbol-table={out_words_file}",
                lm_arpa,
                grammar_graph,
            ],
            check=True,
        )
    return grammar_graph, out_words_file


def create_L(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    lexicon_file: Path,
    in_units_file: Path,
    out_words_file: Path,
) -> Path:
    lexicon_graph = fst_dir / f"L.{unique_label}.fst"

    if not lexicon_graph.exists():
        logger.info(f"Creating {lexicon_graph} (in units: {in_units_file})")
        make_lex = kaldi_root / "egs/wsj/s5/utils/make_lexicon_fst.pl"
        fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
        fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
        fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

        def write_disambig_symbol(file):
            with open(file, "r") as f:
                for line in f:
                    items = line.rstrip().split()
                    if items[0] == "#0":
                        out_path = str(file) + "_disamig"
                        with open(out_path, "w") as out_f:
                            print(items[1], file=out_f)
                            return out_path

            return None

        in_disambig_sym = write_disambig_symbol(in_units_file)
        assert in_disambig_sym is not None
        out_disambig_sym = write_disambig_symbol(out_words_file)
        assert out_disambig_sym is not None

        try:
            with open(lexicon_graph, "wb") as out_f:
                res = subprocess.run(
                    [make_lex, lexicon_file], capture_output=True, check=True
                )
                assert len(res.stderr) == 0, res.stderr.decode("utf-8")
                res = subprocess.run(
                    [
                        fstcompile,
                        f"--isymbols={in_units_file}",
                        f"--osymbols={out_words_file}",
                        "--keep_isymbols=false",
                        "--keep_osymbols=false",
                    ],
                    input=res.stdout,
                    capture_output=True,
                )
                assert len(res.stderr) == 0, res.stderr.decode("utf-8")
                res = subprocess.run(
                    [fstaddselfloops, in_disambig_sym, out_disambig_sym],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstarcsort, "--sort_type=olabel"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(lexicon_graph)
            raise
        except AssertionError:
            os.remove(lexicon_graph)
            raise

    return lexicon_graph


def create_LG(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    lexicon_graph: Path,
    grammar_graph: Path,
) -> Path:
    lg_graph = fst_dir / f"LG.{unique_label}.fst"

    if not lg_graph.exists():
        logger.info(f"Creating {lg_graph}")

        fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
        fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
        fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
        fstpushspecial = kaldi_root / "src/fstbin/fstpushspecial"
        fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

        try:
            with open(lg_graph, "wb") as out_f:
                res = subprocess.run(
                    [fsttablecompose, lexicon_graph, grammar_graph],
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [
                        fstdeterminizestar,
                        "--use-log=true",
                    ],
                    input=res.stdout,
                    capture_output=True,
                )
                res = subprocess.run(
                    [fstminimizeencoded],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstpushspecial],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstarcsort, "--sort_type=ilabel"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(lg_graph)
            raise

    return lg_graph


def create_H(
    kaldi_root: Path,
    fst_dir: Path,
    disambig_out_units_file: Path,
    in_labels: str,
    vocab: Dictionary,
    blk_sym: str,
    silence_symbol: Optional[str],
) -> (Path, Path, Path):
    h_graph = (
        fst_dir / f"H.{in_labels}{'_' + silence_symbol if silence_symbol else ''}.fst"
    )
    h_out_units_file = fst_dir / f"kaldi_dict.h_out.{in_labels}.txt"
    disambig_in_units_file_int = Path(str(h_graph) + "isym_disambig.int")
    disambig_out_units_file_int = Path(str(disambig_out_units_file) + ".int")
    if (
        not h_graph.exists()
        or not h_out_units_file.exists()
        or not disambig_in_units_file_int.exists()
    ):
        logger.info(f"Creating {h_graph}")
        eps_sym = "<eps>"

        num_disambig = 0
        osymbols = []

        with open(disambig_out_units_file, "r") as f, open(
            disambig_out_units_file_int, "w"
        ) as out_f:
            for line in f:
                symb, id = line.rstrip().split()
                if line.startswith("#"):
                    num_disambig += 1
                    print(id, file=out_f)
                else:
                    if len(osymbols) == 0:
                        assert symb == eps_sym, symb
                    osymbols.append((symb, id))

        i_idx = 0
        isymbols = [(eps_sym, 0)]

        imap = {}

        for i, s in enumerate(vocab.symbols):
            i_idx += 1
            isymbols.append((s, i_idx))
            imap[s] = i_idx

        fst_str = []

        node_idx = 0
        root_node = node_idx

        special_symbols = [blk_sym]
        if silence_symbol is not None:
            special_symbols.append(silence_symbol)

        for ss in special_symbols:
            fst_str.append("{} {} {} {}".format(root_node, root_node, ss, eps_sym))

        for symbol, _ in osymbols:
            if symbol == eps_sym or symbol.startswith("#"):
                continue

            node_idx += 1
            # 1. from root to emitting state
            fst_str.append("{} {} {} {}".format(root_node, node_idx, symbol, symbol))
            # 2. from emitting state back to root
            fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))
            # 3. from emitting state to optional blank state
            pre_node = node_idx
            node_idx += 1
            for ss in special_symbols:
                fst_str.append("{} {} {} {}".format(pre_node, node_idx, ss, eps_sym))
            # 4. from blank state back to root
            fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))

        fst_str.append("{}".format(root_node))

        fst_str = "\n".join(fst_str)
        h_str = str(h_graph)
        isym_file = h_str + ".isym"

        with open(isym_file, "w") as f:
            for sym, id in isymbols:
                f.write("{} {}\n".format(sym, id))

        with open(h_out_units_file, "w") as f:
            for sym, id in osymbols:
                f.write("{} {}\n".format(sym, id))

        with open(disambig_in_units_file_int, "w") as f:
            disam_sym_id = len(isymbols)
            for _ in range(num_disambig):
                f.write("{}\n".format(disam_sym_id))
                disam_sym_id += 1

        fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
        fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
        fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

        try:
            with open(h_graph, "wb") as out_f:
                res = subprocess.run(
                    [
                        fstcompile,
                        f"--isymbols={isym_file}",
                        f"--osymbols={h_out_units_file}",
                        "--keep_isymbols=false",
                        "--keep_osymbols=false",
                    ],
                    input=str.encode(fst_str),
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [
                        fstaddselfloops,
                        disambig_in_units_file_int,
                        disambig_out_units_file_int,
                    ],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstarcsort, "--sort_type=olabel"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(h_graph)
            raise
    return h_graph, h_out_units_file, disambig_in_units_file_int


def create_HLGa(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    h_graph: Path,
    lg_graph: Path,
    disambig_in_words_file_int: Path,
) -> Path:
    hlga_graph = fst_dir / f"HLGa.{unique_label}.fst"

    if not hlga_graph.exists():
        logger.info(f"Creating {hlga_graph}")

        fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
        fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
        fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
        fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
        fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"

        try:
            with open(hlga_graph, "wb") as out_f:
                res = subprocess.run(
                    [
                        fsttablecompose,
                        h_graph,
                        lg_graph,
                    ],
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstdeterminizestar, "--use-log=true"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstrmsymbols, disambig_in_words_file_int],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstrmepslocal],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstminimizeencoded],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(hlga_graph)
            raise

    return hlga_graph


def create_HLa(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    h_graph: Path,
    l_graph: Path,
    disambig_in_words_file_int: Path,
) -> Path:
    hla_graph = fst_dir / f"HLa.{unique_label}.fst"

    if not hla_graph.exists():
        logger.info(f"Creating {hla_graph}")

        fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
        fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
        fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
        fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
        fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"

        try:
            with open(hla_graph, "wb") as out_f:
                res = subprocess.run(
                    [
                        fsttablecompose,
                        h_graph,
                        l_graph,
                    ],
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstdeterminizestar, "--use-log=true"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstrmsymbols, disambig_in_words_file_int],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstrmepslocal],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstminimizeencoded],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(hla_graph)
            raise

    return hla_graph


def create_HLG(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    hlga_graph: Path,
    prefix: str = "HLG",
) -> Path:
    hlg_graph = fst_dir / f"{prefix}.{unique_label}.fst"

    if not hlg_graph.exists():
        logger.info(f"Creating {hlg_graph}")

        add_self_loop = script_dir / "add-self-loop-simple"
        kaldi_src = kaldi_root / "src"
        kaldi_lib = kaldi_src / "lib"

        try:
            if not add_self_loop.exists():
                fst_include = kaldi_root / "tools/openfst-1.6.7/include"
                add_self_loop_src = script_dir / "add-self-loop-simple.cc"

                subprocess.run(
                    [
                        "c++",
                        f"-I{kaldi_src}",
                        f"-I{fst_include}",
                        f"-L{kaldi_lib}",
                        add_self_loop_src,
                        "-lkaldi-base",
                        "-lkaldi-fstext",
                        "-o",
                        add_self_loop,
                    ],
                    check=True,
                )

            my_env = os.environ.copy()
            my_env["LD_LIBRARY_PATH"] = f"{kaldi_lib}:{my_env['LD_LIBRARY_PATH']}"

            subprocess.run(
                [
                    add_self_loop,
                    hlga_graph,
                    hlg_graph,
                ],
                check=True,
                capture_output=True,
                env=my_env,
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            raise

    return hlg_graph


def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path:
    if cfg.fst_dir is None:
        cfg.fst_dir = osp.join(cfg.data_dir, "kaldi")
    if cfg.out_labels is None:
        cfg.out_labels = cfg.in_labels

    kaldi_root = Path(cfg.kaldi_root)
    data_dir = Path(cfg.data_dir)
    fst_dir = Path(cfg.fst_dir)
    fst_dir.mkdir(parents=True, exist_ok=True)

    arpa_base = osp.splitext(osp.basename(cfg.lm_arpa))[0]
    unique_label = f"{cfg.in_labels}.{arpa_base}"

    with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as f:
        vocab = Dictionary.load(f)

    in_units_file = create_units(fst_dir, cfg.in_labels, vocab)

    grammar_graph, out_words_file = create_G(
        kaldi_root, fst_dir, Path(cfg.lm_arpa), arpa_base
    )

    disambig_lexicon_file, disambig_L_in_units_file = create_lexicon(
        cfg, fst_dir, unique_label, in_units_file, out_words_file
    )

    h_graph, h_out_units_file, disambig_in_units_file_int = create_H(
        kaldi_root,
        fst_dir,
        disambig_L_in_units_file,
        cfg.in_labels,
        vocab,
        cfg.blank_symbol,
        cfg.silence_symbol,
    )
    lexicon_graph = create_L(
        kaldi_root,
        fst_dir,
        unique_label,
        disambig_lexicon_file,
        disambig_L_in_units_file,
        out_words_file,
    )
    lg_graph = create_LG(
        kaldi_root, fst_dir, unique_label, lexicon_graph, grammar_graph
    )
    hlga_graph = create_HLGa(
        kaldi_root, fst_dir, unique_label, h_graph, lg_graph, disambig_in_units_file_int
    )
    hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph)

    # for debugging
    # hla_graph = create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int)
    # hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped")
    # create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped")

    return hlg_graph


@hydra.main(config_path=config_path, config_name="kaldi_initializer")
def cli_main(cfg: KaldiInitializerConfig) -> None:
    container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
    cfg = OmegaConf.create(container)
    OmegaConf.set_struct(cfg, True)
    initalize_kaldi(cfg)


if __name__ == "__main__":

    logging.root.setLevel(logging.INFO)
    logging.basicConfig(level=logging.INFO)

    try:
        from hydra._internal.utils import (
            get_args,
        )  # pylint: disable=import-outside-toplevel

        cfg_name = get_args().config_name or "kaldi_initializer"
    except ImportError:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "kaldi_initializer"

    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=KaldiInitializerConfig)

    cli_main()
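For reference, a hypothetical programmatic use of the initializer above (not part of the uploaded files): every path, the label set, and the LM file name are placeholders for a real Kaldi checkout, data directory, and ARPA language model.

# Hypothetical usage sketch; paths, label set, and LM are placeholders.
from examples.speech_recognition.kaldi.kaldi_initializer import (
    KaldiInitializerConfig,
    initalize_kaldi,
)

cfg = KaldiInitializerConfig(
    data_dir="/path/to/manifest_dir",   # must contain dict.{in_labels}.txt
    in_labels="ltr",
    kaldi_root="/path/to/kaldi",
    lm_arpa="/path/to/lm.arpa",
    blank_symbol="<s>",
)                                       # fst_dir defaults to {data_dir}/kaldi

hlg_path = initalize_kaldi(cfg)         # builds G, the lexicon, H, L, LG, HLGa, then HLG
print("HLG graph written to", hlg_path)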
fairseq/examples/speech_recognition/models/__init__.py
ADDED
@@ -0,0 +1,8 @@
import importlib
import os


for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        model_name = file[: file.find(".py")]
        importlib.import_module("examples.speech_recognition.models." + model_name)
fairseq/examples/speech_recognition/models/vggtransformer.py
ADDED
@@ -0,0 +1,1020 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import math
|
8 |
+
from collections.abc import Iterable
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
|
13 |
+
from fairseq import utils
|
14 |
+
from fairseq.models import (
|
15 |
+
FairseqEncoder,
|
16 |
+
FairseqEncoderDecoderModel,
|
17 |
+
FairseqEncoderModel,
|
18 |
+
FairseqIncrementalDecoder,
|
19 |
+
register_model,
|
20 |
+
register_model_architecture,
|
21 |
+
)
|
22 |
+
from fairseq.modules import (
|
23 |
+
LinearizedConvolution,
|
24 |
+
TransformerDecoderLayer,
|
25 |
+
TransformerEncoderLayer,
|
26 |
+
VGGBlock,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
@register_model("asr_vggtransformer")
|
31 |
+
class VGGTransformerModel(FairseqEncoderDecoderModel):
|
32 |
+
"""
|
33 |
+
Transformers with convolutional context for ASR
|
34 |
+
https://arxiv.org/abs/1904.11660
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, encoder, decoder):
|
38 |
+
super().__init__(encoder, decoder)
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def add_args(parser):
|
42 |
+
"""Add model-specific arguments to the parser."""
|
43 |
+
parser.add_argument(
|
44 |
+
"--input-feat-per-channel",
|
45 |
+
type=int,
|
46 |
+
metavar="N",
|
47 |
+
help="encoder input dimension per input channel",
|
48 |
+
)
|
49 |
+
parser.add_argument(
|
50 |
+
"--vggblock-enc-config",
|
51 |
+
type=str,
|
52 |
+
metavar="EXPR",
|
53 |
+
help="""
|
54 |
+
an array of tuples each containing the configuration of one vggblock:
|
55 |
+
[(out_channels,
|
56 |
+
conv_kernel_size,
|
57 |
+
pooling_kernel_size,
|
58 |
+
num_conv_layers,
|
59 |
+
use_layer_norm), ...])
|
60 |
+
""",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--transformer-enc-config",
|
64 |
+
type=str,
|
65 |
+
metavar="EXPR",
|
66 |
+
help=""""
|
67 |
+
a tuple containing the configuration of the encoder transformer layers
|
68 |
+
configurations:
|
69 |
+
[(input_dim,
|
70 |
+
num_heads,
|
71 |
+
ffn_dim,
|
72 |
+
normalize_before,
|
73 |
+
dropout,
|
74 |
+
attention_dropout,
|
75 |
+
relu_dropout), ...]')
|
76 |
+
""",
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
"--enc-output-dim",
|
80 |
+
type=int,
|
81 |
+
metavar="N",
|
82 |
+
help="""
|
83 |
+
encoder output dimension, can be None. If specified, projecting the
|
84 |
+
transformer output to the specified dimension""",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--in-channels",
|
88 |
+
type=int,
|
89 |
+
metavar="N",
|
90 |
+
help="number of encoder input channels",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--tgt-embed-dim",
|
94 |
+
type=int,
|
95 |
+
metavar="N",
|
96 |
+
help="embedding dimension of the decoder target tokens",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--transformer-dec-config",
|
100 |
+
type=str,
|
101 |
+
metavar="EXPR",
|
102 |
+
help="""
|
103 |
+
a tuple containing the configuration of the decoder transformer layers
|
104 |
+
configurations:
|
105 |
+
[(input_dim,
|
106 |
+
num_heads,
|
107 |
+
ffn_dim,
|
108 |
+
normalize_before,
|
109 |
+
dropout,
|
110 |
+
attention_dropout,
|
111 |
+
relu_dropout), ...]
|
112 |
+
""",
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
"--conv-dec-config",
|
116 |
+
type=str,
|
117 |
+
metavar="EXPR",
|
118 |
+
help="""
|
119 |
+
an array of tuples for the decoder 1-D convolution config
|
120 |
+
[(out_channels, conv_kernel_size, use_layer_norm), ...]""",
|
121 |
+
)
|
122 |
+
|
123 |
+
@classmethod
|
124 |
+
def build_encoder(cls, args, task):
|
125 |
+
return VGGTransformerEncoder(
|
126 |
+
input_feat_per_channel=args.input_feat_per_channel,
|
127 |
+
vggblock_config=eval(args.vggblock_enc_config),
|
128 |
+
transformer_config=eval(args.transformer_enc_config),
|
129 |
+
encoder_output_dim=args.enc_output_dim,
|
130 |
+
in_channels=args.in_channels,
|
131 |
+
)
|
132 |
+
|
133 |
+
@classmethod
|
134 |
+
def build_decoder(cls, args, task):
|
135 |
+
return TransformerDecoder(
|
136 |
+
dictionary=task.target_dictionary,
|
137 |
+
embed_dim=args.tgt_embed_dim,
|
138 |
+
transformer_config=eval(args.transformer_dec_config),
|
139 |
+
conv_config=eval(args.conv_dec_config),
|
140 |
+
encoder_output_dim=args.enc_output_dim,
|
141 |
+
)
|
142 |
+
|
143 |
+
@classmethod
|
144 |
+
def build_model(cls, args, task):
|
145 |
+
"""Build a new model instance."""
|
146 |
+
# make sure that all args are properly defaulted
|
147 |
+
# (in case there are any new ones)
|
148 |
+
base_architecture(args)
|
149 |
+
|
150 |
+
encoder = cls.build_encoder(args, task)
|
151 |
+
decoder = cls.build_decoder(args, task)
|
152 |
+
return cls(encoder, decoder)
|
153 |
+
|
154 |
+
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
155 |
+
# net_output['encoder_out'] is a (B, T, D) tensor
|
156 |
+
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
|
157 |
+
lprobs.batch_first = True
|
158 |
+
return lprobs
|
159 |
+
|
160 |
+
|
161 |
+
DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2
|
162 |
+
DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2
|
163 |
+
# 256: embedding dimension
|
164 |
+
# 4: number of heads
|
165 |
+
# 1024: FFN
|
166 |
+
# True: apply layerNorm before (dropout + resiaul) instead of after
|
167 |
+
# 0.2 (dropout): dropout after MultiheadAttention and second FC
|
168 |
+
# 0.2 (attention_dropout): dropout in MultiheadAttention
|
169 |
+
# 0.2 (relu_dropout): dropout after ReLu
|
170 |
+
DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2
|
171 |
+
DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2
|
172 |
+
|
173 |
+
|
174 |
+
# TODO: repace transformer encoder config from one liner
|
175 |
+
# to explicit args to get rid of this transformation
|
176 |
+
def prepare_transformer_encoder_params(
|
177 |
+
input_dim,
|
178 |
+
num_heads,
|
179 |
+
ffn_dim,
|
180 |
+
normalize_before,
|
181 |
+
dropout,
|
182 |
+
attention_dropout,
|
183 |
+
relu_dropout,
|
184 |
+
):
|
185 |
+
args = argparse.Namespace()
|
186 |
+
args.encoder_embed_dim = input_dim
|
187 |
+
args.encoder_attention_heads = num_heads
|
188 |
+
args.attention_dropout = attention_dropout
|
189 |
+
args.dropout = dropout
|
190 |
+
args.activation_dropout = relu_dropout
|
191 |
+
args.encoder_normalize_before = normalize_before
|
192 |
+
args.encoder_ffn_embed_dim = ffn_dim
|
193 |
+
return args
|
194 |
+
|
195 |
+
|
196 |
+
def prepare_transformer_decoder_params(
|
197 |
+
input_dim,
|
198 |
+
num_heads,
|
199 |
+
ffn_dim,
|
200 |
+
normalize_before,
|
201 |
+
dropout,
|
202 |
+
attention_dropout,
|
203 |
+
relu_dropout,
|
204 |
+
):
|
205 |
+
args = argparse.Namespace()
|
206 |
+
args.encoder_embed_dim = None
|
207 |
+
args.decoder_embed_dim = input_dim
|
208 |
+
args.decoder_attention_heads = num_heads
|
209 |
+
args.attention_dropout = attention_dropout
|
210 |
+
args.dropout = dropout
|
211 |
+
args.activation_dropout = relu_dropout
|
212 |
+
args.decoder_normalize_before = normalize_before
|
213 |
+
args.decoder_ffn_embed_dim = ffn_dim
|
214 |
+
return args
|
215 |
+
|
216 |
+
|
217 |
+
class VGGTransformerEncoder(FairseqEncoder):
|
218 |
+
"""VGG + Transformer encoder"""
|
219 |
+
|
220 |
+
def __init__(
|
221 |
+
self,
|
222 |
+
input_feat_per_channel,
|
223 |
+
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
|
224 |
+
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
|
225 |
+
encoder_output_dim=512,
|
226 |
+
in_channels=1,
|
227 |
+
transformer_context=None,
|
228 |
+
transformer_sampling=None,
|
229 |
+
):
|
230 |
+
"""constructor for VGGTransformerEncoder
|
231 |
+
|
232 |
+
Args:
|
233 |
+
- input_feat_per_channel: feature dim (not including stacked,
|
234 |
+
just base feature)
|
235 |
+
- in_channel: # input channels (e.g., if stack 8 feature vector
|
236 |
+
together, this is 8)
|
237 |
+
- vggblock_config: configuration of vggblock, see comments on
|
238 |
+
DEFAULT_ENC_VGGBLOCK_CONFIG
|
239 |
+
- transformer_config: configuration of transformer layer, see comments
|
240 |
+
on DEFAULT_ENC_TRANSFORMER_CONFIG
|
241 |
+
- encoder_output_dim: final transformer output embedding dimension
|
242 |
+
- transformer_context: (left, right) if set, self-attention will be focused
|
243 |
+
on (t-left, t+right)
|
244 |
+
- transformer_sampling: an iterable of int, must match with
|
245 |
+
len(transformer_config), transformer_sampling[i] indicates sampling
|
246 |
+
factor for i-th transformer layer, after multihead att and feedfoward
|
247 |
+
part
|
248 |
+
"""
|
249 |
+
super().__init__(None)
|
250 |
+
|
251 |
+
self.num_vggblocks = 0
|
252 |
+
if vggblock_config is not None:
|
253 |
+
if not isinstance(vggblock_config, Iterable):
|
254 |
+
raise ValueError("vggblock_config is not iterable")
|
255 |
+
self.num_vggblocks = len(vggblock_config)
|
256 |
+
|
257 |
+
self.conv_layers = nn.ModuleList()
|
258 |
+
self.in_channels = in_channels
|
259 |
+
self.input_dim = input_feat_per_channel
|
260 |
+
self.pooling_kernel_sizes = []
|
261 |
+
|
262 |
+
if vggblock_config is not None:
|
263 |
+
for _, config in enumerate(vggblock_config):
|
264 |
+
(
|
265 |
+
out_channels,
|
266 |
+
conv_kernel_size,
|
267 |
+
pooling_kernel_size,
|
268 |
+
num_conv_layers,
|
269 |
+
layer_norm,
|
270 |
+
) = config
|
271 |
+
self.conv_layers.append(
|
272 |
+
VGGBlock(
|
273 |
+
in_channels,
|
274 |
+
out_channels,
|
275 |
+
conv_kernel_size,
|
276 |
+
pooling_kernel_size,
|
277 |
+
num_conv_layers,
|
278 |
+
input_dim=input_feat_per_channel,
|
279 |
+
layer_norm=layer_norm,
|
280 |
+
)
|
281 |
+
)
|
282 |
+
self.pooling_kernel_sizes.append(pooling_kernel_size)
|
283 |
+
in_channels = out_channels
|
284 |
+
input_feat_per_channel = self.conv_layers[-1].output_dim
|
285 |
+
|
286 |
+
transformer_input_dim = self.infer_conv_output_dim(
|
287 |
+
self.in_channels, self.input_dim
|
288 |
+
)
|
289 |
+
# transformer_input_dim is the output dimension of VGG part
|
290 |
+
|
291 |
+
self.validate_transformer_config(transformer_config)
|
292 |
+
self.transformer_context = self.parse_transformer_context(transformer_context)
|
293 |
+
self.transformer_sampling = self.parse_transformer_sampling(
|
294 |
+
transformer_sampling, len(transformer_config)
|
295 |
+
)
|
296 |
+
|
297 |
+
self.transformer_layers = nn.ModuleList()
|
298 |
+
|
299 |
+
if transformer_input_dim != transformer_config[0][0]:
|
300 |
+
self.transformer_layers.append(
|
301 |
+
Linear(transformer_input_dim, transformer_config[0][0])
|
302 |
+
)
|
303 |
+
self.transformer_layers.append(
|
304 |
+
TransformerEncoderLayer(
|
305 |
+
prepare_transformer_encoder_params(*transformer_config[0])
|
306 |
+
)
|
307 |
+
)
|
308 |
+
|
309 |
+
for i in range(1, len(transformer_config)):
|
310 |
+
if transformer_config[i - 1][0] != transformer_config[i][0]:
|
311 |
+
self.transformer_layers.append(
|
312 |
+
Linear(transformer_config[i - 1][0], transformer_config[i][0])
|
313 |
+
)
|
314 |
+
self.transformer_layers.append(
|
315 |
+
TransformerEncoderLayer(
|
316 |
+
prepare_transformer_encoder_params(*transformer_config[i])
|
317 |
+
)
|
318 |
+
)
|
319 |
+
|
320 |
+
self.encoder_output_dim = encoder_output_dim
|
321 |
+
self.transformer_layers.extend(
|
322 |
+
[
|
323 |
+
Linear(transformer_config[-1][0], encoder_output_dim),
|
324 |
+
LayerNorm(encoder_output_dim),
|
325 |
+
]
|
326 |
+
)
|
327 |
+
|
328 |
+
def forward(self, src_tokens, src_lengths, **kwargs):
|
329 |
+
"""
|
330 |
+
src_tokens: padded tensor (B, T, C * feat)
|
331 |
+
src_lengths: tensor of original lengths of input utterances (B,)
|
332 |
+
"""
|
333 |
+
bsz, max_seq_len, _ = src_tokens.size()
|
334 |
+
x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
|
335 |
+
x = x.transpose(1, 2).contiguous()
|
336 |
+
# (B, C, T, feat)
|
337 |
+
|
338 |
+
for layer_idx in range(len(self.conv_layers)):
|
339 |
+
x = self.conv_layers[layer_idx](x)
|
340 |
+
|
341 |
+
bsz, _, output_seq_len, _ = x.size()
|
342 |
+
|
343 |
+
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
|
344 |
+
x = x.transpose(1, 2).transpose(0, 1)
|
345 |
+
x = x.contiguous().view(output_seq_len, bsz, -1)
|
346 |
+
|
347 |
+
input_lengths = src_lengths.clone()
|
348 |
+
for s in self.pooling_kernel_sizes:
|
349 |
+
input_lengths = (input_lengths.float() / s).ceil().long()
|
350 |
+
|
351 |
+
encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
|
352 |
+
input_lengths, batch_first=True
|
353 |
+
)
|
354 |
+
if not encoder_padding_mask.any():
|
355 |
+
encoder_padding_mask = None
|
356 |
+
|
357 |
+
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
|
358 |
+
attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)
|
359 |
+
|
360 |
+
transformer_layer_idx = 0
|
361 |
+
|
362 |
+
for layer_idx in range(len(self.transformer_layers)):
|
363 |
+
|
364 |
+
if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer):
|
365 |
+
x = self.transformer_layers[layer_idx](
|
366 |
+
x, encoder_padding_mask, attn_mask
|
367 |
+
)
|
368 |
+
|
369 |
+
if self.transformer_sampling[transformer_layer_idx] != 1:
|
370 |
+
sampling_factor = self.transformer_sampling[transformer_layer_idx]
|
371 |
+
x, encoder_padding_mask, attn_mask = self.slice(
|
372 |
+
x, encoder_padding_mask, attn_mask, sampling_factor
|
373 |
+
)
|
374 |
+
|
375 |
+
transformer_layer_idx += 1
|
376 |
+
|
377 |
+
else:
|
378 |
+
x = self.transformer_layers[layer_idx](x)
|
379 |
+
|
380 |
+
# encoder_padding_maks is a (T x B) tensor, its [t, b] elements indicate
|
381 |
+
# whether encoder_output[t, b] is valid or not (valid=0, invalid=1)
|
382 |
+
|
383 |
+
return {
|
384 |
+
"encoder_out": x, # (T, B, C)
|
385 |
+
"encoder_padding_mask": encoder_padding_mask.t()
|
386 |
+
if encoder_padding_mask is not None
|
387 |
+
else None,
|
388 |
+
# (B, T) --> (T, B)
|
389 |
+
}
|
390 |
+
|
391 |
+
def infer_conv_output_dim(self, in_channels, input_dim):
|
392 |
+
sample_seq_len = 200
|
393 |
+
sample_bsz = 10
|
394 |
+
x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
|
395 |
+
for i, _ in enumerate(self.conv_layers):
|
396 |
+
x = self.conv_layers[i](x)
|
397 |
+
x = x.transpose(1, 2)
|
398 |
+
mb, seq = x.size()[:2]
|
399 |
+
return x.contiguous().view(mb, seq, -1).size(-1)
|
400 |
+
|
401 |
+
def validate_transformer_config(self, transformer_config):
|
402 |
+
for config in transformer_config:
|
403 |
+
input_dim, num_heads = config[:2]
|
404 |
+
if input_dim % num_heads != 0:
|
405 |
+
msg = (
|
406 |
+
"ERROR in transformer config {}: ".format(config)
|
407 |
+
+ "input dimension {} ".format(input_dim)
|
408 |
+
+ "not dividable by number of heads {}".format(num_heads)
|
409 |
+
)
|
410 |
+
raise ValueError(msg)
|
411 |
+
|
412 |
+
def parse_transformer_context(self, transformer_context):
|
413 |
+
"""
|
414 |
+
transformer_context can be the following:
|
415 |
+
- None; indicates no context is used, i.e.,
|
416 |
+
transformer can access full context
|
417 |
+
- a tuple/list of two int; indicates left and right context,
|
418 |
+
any number <0 indicates infinite context
|
419 |
+
* e.g., (5, 6) indicates that for query at x_t, transformer can
|
420 |
+
access [t-5, t+6] (inclusive)
|
421 |
+
* e.g., (-1, 6) indicates that for query at x_t, transformer can
|
422 |
+
access [0, t+6] (inclusive)
|
423 |
+
"""
|
424 |
+
if transformer_context is None:
|
425 |
+
return None
|
426 |
+
|
427 |
+
if not isinstance(transformer_context, Iterable):
|
428 |
+
raise ValueError("transformer context must be Iterable if it is not None")
|
429 |
+
|
430 |
+
if len(transformer_context) != 2:
|
431 |
+
raise ValueError("transformer context must have length 2")
|
432 |
+
|
433 |
+
left_context = transformer_context[0]
|
434 |
+
if left_context < 0:
|
435 |
+
left_context = None
|
436 |
+
|
437 |
+
right_context = transformer_context[1]
|
438 |
+
if right_context < 0:
|
439 |
+
right_context = None
|
440 |
+
|
441 |
+
if left_context is None and right_context is None:
|
442 |
+
return None
|
443 |
+
|
444 |
+
return (left_context, right_context)
|
445 |
+
|
446 |
+
def parse_transformer_sampling(self, transformer_sampling, num_layers):
|
447 |
+
"""
|
448 |
+
parsing transformer sampling configuration
|
449 |
+
|
450 |
+
Args:
|
451 |
+
- transformer_sampling, accepted input:
|
452 |
+
* None, indicating no sampling
|
453 |
+
* an Iterable with int (>0) as element
|
454 |
+
- num_layers, expected number of transformer layers, must match with
|
455 |
+
the length of transformer_sampling if it is not None
|
456 |
+
|
457 |
+
Returns:
|
458 |
+
- A tuple with length num_layers
|
459 |
+
"""
|
460 |
+
if transformer_sampling is None:
|
461 |
+
return (1,) * num_layers
|
462 |
+
|
463 |
+
if not isinstance(transformer_sampling, Iterable):
|
464 |
+
raise ValueError(
|
465 |
+
"transformer_sampling must be an iterable if it is not None"
|
466 |
+
)
|
467 |
+
|
468 |
+
if len(transformer_sampling) != num_layers:
|
469 |
+
raise ValueError(
|
470 |
+
"transformer_sampling {} does not match with the number "
|
471 |
+
"of layers {}".format(transformer_sampling, num_layers)
|
472 |
+
)
|
473 |
+
|
474 |
+
for layer, value in enumerate(transformer_sampling):
|
475 |
+
if not isinstance(value, int):
|
476 |
+
raise ValueError("Invalid value in transformer_sampling: ")
|
477 |
+
if value < 1:
|
478 |
+
raise ValueError(
|
479 |
+
"{} layer's subsampling is {}.".format(layer, value)
|
480 |
+
+ " This is not allowed! "
|
481 |
+
)
|
482 |
+
return transformer_sampling
|
483 |
+
|
484 |
+
def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
|
485 |
+
"""
|
486 |
+
embedding is a (T, B, D) tensor
|
487 |
+
padding_mask is a (B, T) tensor or None
|
488 |
+
attn_mask is a (T, T) tensor or None
|
489 |
+
"""
|
490 |
+
embedding = embedding[::sampling_factor, :, :]
|
491 |
+
if padding_mask is not None:
|
492 |
+
padding_mask = padding_mask[:, ::sampling_factor]
|
493 |
+
if attn_mask is not None:
|
494 |
+
attn_mask = attn_mask[::sampling_factor, ::sampling_factor]
|
495 |
+
|
496 |
+
return embedding, padding_mask, attn_mask
|
497 |
+
|
498 |
+
def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
|
499 |
+
"""
|
500 |
+
create attention mask according to sequence lengths and transformer
|
501 |
+
context
|
502 |
+
|
503 |
+
Args:
|
504 |
+
- input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
|
505 |
+
the length of b-th sequence
|
506 |
+
- subsampling_factor: int
|
507 |
+
* Note that the left_context and right_context is specified in
|
508 |
+
the input frame-level while input to transformer may already
|
509 |
+
go through subsampling (e.g., the use of striding in vggblock)
|
510 |
+
we use subsampling_factor to scale the left/right context
|
511 |
+
|
512 |
+
Return:
|
513 |
+
- a (T, T) binary tensor or None, where T is max(input_lengths)
|
514 |
+
* if self.transformer_context is None, None
|
515 |
+
* if left_context is None,
|
516 |
+
* attn_mask[t, t + right_context + 1:] = 1
|
517 |
+
* others = 0
|
518 |
+
* if right_context is None,
|
519 |
+
* attn_mask[t, 0:t - left_context] = 1
|
520 |
+
* others = 0
|
521 |
+
* elsif
|
522 |
+
* attn_mask[t, t - left_context: t + right_context + 1] = 0
|
523 |
+
* others = 1
|
524 |
+
"""
|
525 |
+
if self.transformer_context is None:
|
526 |
+
return None
|
527 |
+
|
528 |
+
maxT = torch.max(input_lengths).item()
|
529 |
+
attn_mask = torch.zeros(maxT, maxT)
|
530 |
+
|
531 |
+
left_context = self.transformer_context[0]
|
532 |
+
right_context = self.transformer_context[1]
|
533 |
+
if left_context is not None:
|
534 |
+
left_context = math.ceil(self.transformer_context[0] / subsampling_factor)
|
535 |
+
if right_context is not None:
|
536 |
+
right_context = math.ceil(self.transformer_context[1] / subsampling_factor)
|
537 |
+
|
538 |
+
for t in range(maxT):
|
539 |
+
if left_context is not None:
|
540 |
+
st = 0
|
541 |
+
en = max(st, t - left_context)
|
542 |
+
attn_mask[t, st:en] = 1
|
543 |
+
if right_context is not None:
|
544 |
+
st = t + right_context + 1
|
545 |
+
st = min(st, maxT - 1)
|
546 |
+
attn_mask[t, st:] = 1
|
547 |
+
|
548 |
+
return attn_mask.to(input_lengths.device)
|
549 |
+
|
550 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
551 |
+
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
|
552 |
+
1, new_order
|
553 |
+
)
|
554 |
+
if encoder_out["encoder_padding_mask"] is not None:
|
555 |
+
encoder_out["encoder_padding_mask"] = encoder_out[
|
556 |
+
"encoder_padding_mask"
|
557 |
+
].index_select(1, new_order)
|
558 |
+
return encoder_out
|
559 |
+
|
560 |
+
|
561 |
+
class TransformerDecoder(FairseqIncrementalDecoder):
|
562 |
+
"""
|
563 |
+
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
|
564 |
+
is a :class:`TransformerDecoderLayer`.
|
565 |
+
Args:
|
566 |
+
args (argparse.Namespace): parsed command-line arguments
|
567 |
+
dictionary (~fairseq.data.Dictionary): decoding dictionary
|
568 |
+
embed_tokens (torch.nn.Embedding): output embedding
|
569 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
|
570 |
+
Default: ``False``
|
571 |
+
left_pad (bool, optional): whether the input is left-padded. Default:
|
572 |
+
``False``
|
573 |
+
"""
|
574 |
+
|
575 |
+
def __init__(
|
576 |
+
self,
|
577 |
+
dictionary,
|
578 |
+
embed_dim=512,
|
579 |
+
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
|
580 |
+
conv_config=DEFAULT_DEC_CONV_CONFIG,
|
581 |
+
encoder_output_dim=512,
|
582 |
+
):
|
583 |
+
|
584 |
+
super().__init__(dictionary)
|
585 |
+
vocab_size = len(dictionary)
|
586 |
+
self.padding_idx = dictionary.pad()
|
587 |
+
self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)
|
588 |
+
|
589 |
+
self.conv_layers = nn.ModuleList()
|
590 |
+
for i in range(len(conv_config)):
|
591 |
+
out_channels, kernel_size, layer_norm = conv_config[i]
|
592 |
+
if i == 0:
|
593 |
+
conv_layer = LinearizedConv1d(
|
594 |
+
embed_dim, out_channels, kernel_size, padding=kernel_size - 1
|
595 |
+
)
|
596 |
+
else:
|
597 |
+
conv_layer = LinearizedConv1d(
|
598 |
+
conv_config[i - 1][0],
|
599 |
+
out_channels,
|
600 |
+
kernel_size,
|
601 |
+
padding=kernel_size - 1,
|
602 |
+
)
|
603 |
+
self.conv_layers.append(conv_layer)
|
604 |
+
if layer_norm:
|
605 |
+
self.conv_layers.append(nn.LayerNorm(out_channels))
|
606 |
+
self.conv_layers.append(nn.ReLU())
|
607 |
+
|
608 |
+
self.layers = nn.ModuleList()
|
609 |
+
if conv_config[-1][0] != transformer_config[0][0]:
|
610 |
+
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
|
611 |
+
self.layers.append(
|
612 |
+
TransformerDecoderLayer(
|
613 |
+
prepare_transformer_decoder_params(*transformer_config[0])
|
614 |
+
)
|
615 |
+
)
|
616 |
+
|
617 |
+
for i in range(1, len(transformer_config)):
|
618 |
+
if transformer_config[i - 1][0] != transformer_config[i][0]:
|
619 |
+
self.layers.append(
|
620 |
+
Linear(transformer_config[i - 1][0], transformer_config[i][0])
|
621 |
+
)
|
622 |
+
self.layers.append(
|
623 |
+
TransformerDecoderLayer(
|
624 |
+
prepare_transformer_decoder_params(*transformer_config[i])
|
625 |
+
)
|
626 |
+
)
|
627 |
+
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
|
628 |
+
|
629 |
+
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
|
630 |
+
"""
|
631 |
+
Args:
|
632 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
633 |
+
`(batch, tgt_len)`, for input feeding/teacher forcing
|
634 |
+
encoder_out (Tensor, optional): output from the encoder, used for
|
635 |
+
encoder-side attention
|
636 |
+
incremental_state (dict): dictionary used for storing state during
|
637 |
+
:ref:`Incremental decoding`
|
638 |
+
Returns:
|
639 |
+
tuple:
|
640 |
+
- the last decoder layer's output of shape `(batch, tgt_len,
|
641 |
+
vocab)`
|
642 |
+
- the last decoder layer's attention weights of shape `(batch,
|
643 |
+
tgt_len, src_len)`
|
644 |
+
"""
|
645 |
+
target_padding_mask = (
|
646 |
+
(prev_output_tokens == self.padding_idx).to(prev_output_tokens.device)
|
647 |
+
if incremental_state is None
|
648 |
+
else None
|
649 |
+
)
|
650 |
+
|
651 |
+
if incremental_state is not None:
|
652 |
+
prev_output_tokens = prev_output_tokens[:, -1:]
|
653 |
+
|
654 |
+
# embed tokens
|
655 |
+
x = self.embed_tokens(prev_output_tokens)
|
656 |
+
|
657 |
+
# B x T x C -> T x B x C
|
658 |
+
x = self._transpose_if_training(x, incremental_state)
|
659 |
+
|
660 |
+
for layer in self.conv_layers:
|
661 |
+
if isinstance(layer, LinearizedConvolution):
|
662 |
+
x = layer(x, incremental_state)
|
663 |
+
else:
|
664 |
+
x = layer(x)
|
665 |
+
|
666 |
+
# B x T x C -> T x B x C
|
667 |
+
x = self._transpose_if_inference(x, incremental_state)
|
668 |
+
|
669 |
+
# decoder layers
|
670 |
+
for layer in self.layers:
|
671 |
+
if isinstance(layer, TransformerDecoderLayer):
|
672 |
+
x, *_ = layer(
|
673 |
+
x,
|
674 |
+
(encoder_out["encoder_out"] if encoder_out is not None else None),
|
675 |
+
(
|
676 |
+
encoder_out["encoder_padding_mask"].t()
|
677 |
+
if encoder_out["encoder_padding_mask"] is not None
|
678 |
+
else None
|
679 |
+
),
|
680 |
+
incremental_state,
|
681 |
+
self_attn_mask=(
|
682 |
+
self.buffered_future_mask(x)
|
683 |
+
if incremental_state is None
|
684 |
+
else None
|
685 |
+
),
|
686 |
+
self_attn_padding_mask=(
|
687 |
+
target_padding_mask if incremental_state is None else None
|
688 |
+
),
|
689 |
+
)
|
690 |
+
else:
|
691 |
+
x = layer(x)
|
692 |
+
|
693 |
+
# T x B x C -> B x T x C
|
694 |
+
x = x.transpose(0, 1)
|
695 |
+
|
696 |
+
x = self.fc_out(x)
|
697 |
+
|
698 |
+
return x, None
|
699 |
+
|
700 |
+
def buffered_future_mask(self, tensor):
|
701 |
+
dim = tensor.size(0)
|
702 |
+
if (
|
703 |
+
not hasattr(self, "_future_mask")
|
704 |
+
or self._future_mask is None
|
705 |
+
or self._future_mask.device != tensor.device
|
706 |
+
):
|
707 |
+
self._future_mask = torch.triu(
|
708 |
+
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
|
709 |
+
)
|
710 |
+
if self._future_mask.size(0) < dim:
|
711 |
+
self._future_mask = torch.triu(
|
712 |
+
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
|
713 |
+
)
|
714 |
+
return self._future_mask[:dim, :dim]
|
715 |
+
|
716 |
+
def _transpose_if_training(self, x, incremental_state):
|
717 |
+
if incremental_state is None:
|
718 |
+
x = x.transpose(0, 1)
|
719 |
+
return x
|
720 |
+
|
721 |
+
def _transpose_if_inference(self, x, incremental_state):
|
722 |
+
if incremental_state:
|
723 |
+
x = x.transpose(0, 1)
|
724 |
+
return x
|
725 |
+
|
726 |
+
|
727 |
+
@register_model("asr_vggtransformer_encoder")
|
728 |
+
class VGGTransformerEncoderModel(FairseqEncoderModel):
|
729 |
+
def __init__(self, encoder):
|
730 |
+
super().__init__(encoder)
|
731 |
+
|
732 |
+
@staticmethod
|
733 |
+
def add_args(parser):
|
734 |
+
"""Add model-specific arguments to the parser."""
|
735 |
+
parser.add_argument(
|
736 |
+
"--input-feat-per-channel",
|
737 |
+
type=int,
|
738 |
+
metavar="N",
|
739 |
+
help="encoder input dimension per input channel",
|
740 |
+
)
|
741 |
+
parser.add_argument(
|
742 |
+
"--vggblock-enc-config",
|
743 |
+
type=str,
|
744 |
+
metavar="EXPR",
|
745 |
+
help="""
|
746 |
+
an array of tuples each containing the configuration of one vggblock
|
747 |
+
[(out_channels, conv_kernel_size, pooling_kernel_size,num_conv_layers), ...]
|
748 |
+
""",
|
749 |
+
)
|
750 |
+
parser.add_argument(
|
751 |
+
"--transformer-enc-config",
|
752 |
+
type=str,
|
753 |
+
metavar="EXPR",
|
754 |
+
help="""
|
755 |
+
a tuple containing the configuration of the Transformer layers
|
756 |
+
configurations:
|
757 |
+
[(input_dim,
|
758 |
+
num_heads,
|
759 |
+
ffn_dim,
|
760 |
+
normalize_before,
|
761 |
+
dropout,
|
762 |
+
attention_dropout,
|
763 |
+
relu_dropout), ]""",
|
764 |
+
)
|
765 |
+
parser.add_argument(
|
766 |
+
"--enc-output-dim",
|
767 |
+
type=int,
|
768 |
+
metavar="N",
|
769 |
+
help="encoder output dimension, projecting the LSTM output",
|
770 |
+
)
|
771 |
+
parser.add_argument(
|
772 |
+
"--in-channels",
|
773 |
+
type=int,
|
774 |
+
metavar="N",
|
775 |
+
help="number of encoder input channels",
|
776 |
+
)
|
777 |
+
parser.add_argument(
|
778 |
+
"--transformer-context",
|
779 |
+
type=str,
|
780 |
+
metavar="EXPR",
|
781 |
+
help="""
|
782 |
+
either None or a tuple of two ints, indicating left/right context a
|
783 |
+
transformer can have access to""",
|
784 |
+
)
|
785 |
+
parser.add_argument(
|
786 |
+
"--transformer-sampling",
|
787 |
+
type=str,
|
788 |
+
metavar="EXPR",
|
789 |
+
help="""
|
790 |
+
either None or a tuple of ints, indicating sampling factor in each layer""",
|
791 |
+
)
|
792 |
+
|
793 |
+
@classmethod
|
794 |
+
def build_model(cls, args, task):
|
795 |
+
"""Build a new model instance."""
|
796 |
+
base_architecture_enconly(args)
|
797 |
+
encoder = VGGTransformerEncoderOnly(
|
798 |
+
vocab_size=len(task.target_dictionary),
|
799 |
+
input_feat_per_channel=args.input_feat_per_channel,
|
800 |
+
vggblock_config=eval(args.vggblock_enc_config),
|
801 |
+
transformer_config=eval(args.transformer_enc_config),
|
802 |
+
encoder_output_dim=args.enc_output_dim,
|
803 |
+
in_channels=args.in_channels,
|
804 |
+
transformer_context=eval(args.transformer_context),
|
805 |
+
transformer_sampling=eval(args.transformer_sampling),
|
806 |
+
)
|
807 |
+
return cls(encoder)
|
808 |
+
|
809 |
+
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
810 |
+
# net_output['encoder_out'] is a (T, B, D) tensor
|
811 |
+
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
|
812 |
+
# lprobs is a (T, B, D) tensor
|
813 |
+
# we need to transoose to get (B, T, D) tensor
|
814 |
+
lprobs = lprobs.transpose(0, 1).contiguous()
|
815 |
+
lprobs.batch_first = True
|
816 |
+
return lprobs
|
817 |
+
|
818 |
+
|
819 |
+
class VGGTransformerEncoderOnly(VGGTransformerEncoder):
|
820 |
+
def __init__(
|
821 |
+
self,
|
822 |
+
vocab_size,
|
823 |
+
input_feat_per_channel,
|
824 |
+
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
|
825 |
+
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
|
826 |
+
encoder_output_dim=512,
|
827 |
+
in_channels=1,
|
828 |
+
transformer_context=None,
|
829 |
+
transformer_sampling=None,
|
830 |
+
):
|
831 |
+
super().__init__(
|
832 |
+
input_feat_per_channel=input_feat_per_channel,
|
833 |
+
vggblock_config=vggblock_config,
|
834 |
+
transformer_config=transformer_config,
|
835 |
+
encoder_output_dim=encoder_output_dim,
|
836 |
+
in_channels=in_channels,
|
837 |
+
transformer_context=transformer_context,
|
838 |
+
transformer_sampling=transformer_sampling,
|
839 |
+
)
|
840 |
+
self.fc_out = Linear(self.encoder_output_dim, vocab_size)
|
841 |
+
|
842 |
+
def forward(self, src_tokens, src_lengths, **kwargs):
|
843 |
+
"""
|
844 |
+
src_tokens: padded tensor (B, T, C * feat)
|
845 |
+
src_lengths: tensor of original lengths of input utterances (B,)
|
846 |
+
"""
|
847 |
+
|
848 |
+
enc_out = super().forward(src_tokens, src_lengths)
|
849 |
+
x = self.fc_out(enc_out["encoder_out"])
|
850 |
+
# x = F.log_softmax(x, dim=-1)
|
851 |
+
# Note: no need this line, because model.get_normalized_prob will call
|
852 |
+
# log_softmax
|
853 |
+
return {
|
854 |
+
"encoder_out": x, # (T, B, C)
|
855 |
+
"encoder_padding_mask": enc_out["encoder_padding_mask"], # (T, B)
|
856 |
+
}
|
857 |
+
|
858 |
+
def max_positions(self):
|
859 |
+
"""Maximum input length supported by the encoder."""
|
860 |
+
return (1e6, 1e6) # an arbitrary large number
|
861 |
+
|
862 |
+
|
863 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx):
|
864 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
865 |
+
# nn.init.uniform_(m.weight, -0.1, 0.1)
|
866 |
+
# nn.init.constant_(m.weight[padding_idx], 0)
|
867 |
+
return m
|
868 |
+
|
869 |
+
|
870 |
+
def Linear(in_features, out_features, bias=True, dropout=0):
|
871 |
+
"""Linear layer (input: N x T x C)"""
|
872 |
+
m = nn.Linear(in_features, out_features, bias=bias)
|
873 |
+
# m.weight.data.uniform_(-0.1, 0.1)
|
874 |
+
# if bias:
|
875 |
+
# m.bias.data.uniform_(-0.1, 0.1)
|
876 |
+
return m
|
877 |
+
|
878 |
+
|
879 |
+
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
|
880 |
+
"""Weight-normalized Conv1d layer optimized for decoding"""
|
881 |
+
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
|
882 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
|
883 |
+
nn.init.normal_(m.weight, mean=0, std=std)
|
884 |
+
nn.init.constant_(m.bias, 0)
|
885 |
+
return nn.utils.weight_norm(m, dim=2)
|
886 |
+
|
887 |
+
|
888 |
+
def LayerNorm(embedding_dim):
|
889 |
+
m = nn.LayerNorm(embedding_dim)
|
890 |
+
return m
|
891 |
+
|
892 |
+
|
893 |
+
# seq2seq models
|
894 |
+
def base_architecture(args):
|
895 |
+
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
|
896 |
+
args.vggblock_enc_config = getattr(
|
897 |
+
args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG
|
898 |
+
)
|
899 |
+
args.transformer_enc_config = getattr(
|
900 |
+
args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG
|
901 |
+
)
|
902 |
+
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
|
903 |
+
args.in_channels = getattr(args, "in_channels", 1)
|
904 |
+
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
|
905 |
+
args.transformer_dec_config = getattr(
|
906 |
+
args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG
|
907 |
+
)
|
908 |
+
args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG)
|
909 |
+
args.transformer_context = getattr(args, "transformer_context", "None")
|
910 |
+
|
911 |
+
|
912 |
+
@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
|
913 |
+
def vggtransformer_1(args):
|
914 |
+
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
915 |
+
args.vggblock_enc_config = getattr(
|
916 |
+
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
|
917 |
+
)
|
918 |
+
args.transformer_enc_config = getattr(
|
919 |
+
args,
|
920 |
+
"transformer_enc_config",
|
921 |
+
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14",
|
922 |
+
)
|
923 |
+
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
|
924 |
+
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
|
925 |
+
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
|
926 |
+
args.transformer_dec_config = getattr(
|
927 |
+
args,
|
928 |
+
"transformer_dec_config",
|
929 |
+
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4",
|
930 |
+
)
|
931 |
+
|
932 |
+
|
933 |
+
@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
|
934 |
+
def vggtransformer_2(args):
|
935 |
+
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
936 |
+
args.vggblock_enc_config = getattr(
|
937 |
+
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
|
938 |
+
)
|
939 |
+
args.transformer_enc_config = getattr(
|
940 |
+
args,
|
941 |
+
"transformer_enc_config",
|
942 |
+
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
|
943 |
+
)
|
944 |
+
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
|
945 |
+
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
|
946 |
+
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
|
947 |
+
args.transformer_dec_config = getattr(
|
948 |
+
args,
|
949 |
+
"transformer_dec_config",
|
950 |
+
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6",
|
951 |
+
)
|
952 |
+
|
953 |
+
|
954 |
+
@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
|
955 |
+
def vggtransformer_base(args):
|
956 |
+
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
957 |
+
args.vggblock_enc_config = getattr(
|
958 |
+
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
|
959 |
+
)
|
960 |
+
args.transformer_enc_config = getattr(
|
961 |
+
args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
|
962 |
+
)
|
963 |
+
|
964 |
+
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
|
965 |
+
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
|
966 |
+
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
|
967 |
+
args.transformer_dec_config = getattr(
|
968 |
+
args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
|
969 |
+
)
|
970 |
+
# Size estimations:
|
971 |
+
# Encoder:
|
972 |
+
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3 = 258K
|
973 |
+
# Transformer:
|
974 |
+
# - input dimension adapter: 2560 x 512 -> 1.31M
|
975 |
+
# - transformer_layers (x12) --> 37.74M
|
976 |
+
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
|
977 |
+
# * FFN weight: 512*2048*2 = 2.097M
|
978 |
+
# - output dimension adapter: 512 x 512 -> 0.26 M
|
979 |
+
# Decoder:
|
980 |
+
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
|
981 |
+
# - transformer_layer: (x6) --> 25.16M
|
982 |
+
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
|
983 |
+
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
|
984 |
+
# * FFN: 512*2048*2 = 2.097M
|
985 |
+
# Final FC:
|
986 |
+
# - FC: 512*5000 = 256K (assuming vocab size 5K)
|
987 |
+
# In total:
|
988 |
+
# ~65 M
|
989 |
+
|
990 |
+
|
991 |
+
# CTC models
|
992 |
+
def base_architecture_enconly(args):
|
993 |
+
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
|
994 |
+
args.vggblock_enc_config = getattr(
|
995 |
+
args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2"
|
996 |
+
)
|
997 |
+
args.transformer_enc_config = getattr(
|
998 |
+
args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
|
999 |
+
)
|
1000 |
+
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
|
1001 |
+
args.in_channels = getattr(args, "in_channels", 1)
|
1002 |
+
args.transformer_context = getattr(args, "transformer_context", "None")
|
1003 |
+
args.transformer_sampling = getattr(args, "transformer_sampling", "None")
|
1004 |
+
|
1005 |
+
|
1006 |
+
@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1")
|
1007 |
+
def vggtransformer_enc_1(args):
|
1008 |
+
# vggtransformer_1 is the same as vggtransformer_enc_big, except the number
|
1009 |
+
# of layers is increased to 16
|
1010 |
+
# keep it here for backward compatiablity purpose
|
1011 |
+
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
1012 |
+
args.vggblock_enc_config = getattr(
|
1013 |
+
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
|
1014 |
+
)
|
1015 |
+
args.transformer_enc_config = getattr(
|
1016 |
+
args,
|
1017 |
+
"transformer_enc_config",
|
1018 |
+
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
|
1019 |
+
)
|
1020 |
+
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
|
fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from fairseq.models import (
|
14 |
+
FairseqEncoder,
|
15 |
+
FairseqEncoderModel,
|
16 |
+
register_model,
|
17 |
+
register_model_architecture,
|
18 |
+
)
|
19 |
+
from fairseq.modules.fairseq_dropout import FairseqDropout
|
20 |
+
|
21 |
+
|
22 |
+
default_conv_enc_config = """[
|
23 |
+
(400, 13, 170, 0.2),
|
24 |
+
(440, 14, 0, 0.214),
|
25 |
+
(484, 15, 0, 0.22898),
|
26 |
+
(532, 16, 0, 0.2450086),
|
27 |
+
(584, 17, 0, 0.262159202),
|
28 |
+
(642, 18, 0, 0.28051034614),
|
29 |
+
(706, 19, 0, 0.30014607037),
|
30 |
+
(776, 20, 0, 0.321156295296),
|
31 |
+
(852, 21, 0, 0.343637235966),
|
32 |
+
(936, 22, 0, 0.367691842484),
|
33 |
+
(1028, 23, 0, 0.393430271458),
|
34 |
+
(1130, 24, 0, 0.42097039046),
|
35 |
+
(1242, 25, 0, 0.450438317792),
|
36 |
+
(1366, 26, 0, 0.481969000038),
|
37 |
+
(1502, 27, 0, 0.51570683004),
|
38 |
+
(1652, 28, 0, 0.551806308143),
|
39 |
+
(1816, 29, 0, 0.590432749713),
|
40 |
+
]"""
|
41 |
+
|
42 |
+
|
43 |
+
@register_model("asr_w2l_conv_glu_encoder")
|
44 |
+
class W2lConvGluEncoderModel(FairseqEncoderModel):
|
45 |
+
def __init__(self, encoder):
|
46 |
+
super().__init__(encoder)
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def add_args(parser):
|
50 |
+
"""Add model-specific arguments to the parser."""
|
51 |
+
parser.add_argument(
|
52 |
+
"--input-feat-per-channel",
|
53 |
+
type=int,
|
54 |
+
metavar="N",
|
55 |
+
help="encoder input dimension per input channel",
|
56 |
+
)
|
57 |
+
parser.add_argument(
|
58 |
+
"--in-channels",
|
59 |
+
type=int,
|
60 |
+
metavar="N",
|
61 |
+
help="number of encoder input channels",
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--conv-enc-config",
|
65 |
+
type=str,
|
66 |
+
metavar="EXPR",
|
67 |
+
help="""
|
68 |
+
an array of tuples each containing the configuration of one conv layer
|
69 |
+
[(out_channels, kernel_size, padding, dropout), ...]
|
70 |
+
""",
|
71 |
+
)
|
72 |
+
|
73 |
+
@classmethod
|
74 |
+
def build_model(cls, args, task):
|
75 |
+
"""Build a new model instance."""
|
76 |
+
conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
|
77 |
+
encoder = W2lConvGluEncoder(
|
78 |
+
vocab_size=len(task.target_dictionary),
|
79 |
+
input_feat_per_channel=args.input_feat_per_channel,
|
80 |
+
in_channels=args.in_channels,
|
81 |
+
conv_enc_config=eval(conv_enc_config),
|
82 |
+
)
|
83 |
+
return cls(encoder)
|
84 |
+
|
85 |
+
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
86 |
+
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
|
87 |
+
lprobs.batch_first = False
|
88 |
+
return lprobs
|
89 |
+
|
90 |
+
|
91 |
+
class W2lConvGluEncoder(FairseqEncoder):
|
92 |
+
def __init__(
|
93 |
+
self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
|
94 |
+
):
|
95 |
+
super().__init__(None)
|
96 |
+
|
97 |
+
self.input_dim = input_feat_per_channel
|
98 |
+
if in_channels != 1:
|
99 |
+
raise ValueError("only 1 input channel is currently supported")
|
100 |
+
|
101 |
+
self.conv_layers = nn.ModuleList()
|
102 |
+
self.linear_layers = nn.ModuleList()
|
103 |
+
self.dropouts = []
|
104 |
+
cur_channels = input_feat_per_channel
|
105 |
+
|
106 |
+
for out_channels, kernel_size, padding, dropout in conv_enc_config:
|
107 |
+
layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
|
108 |
+
layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init
|
109 |
+
self.conv_layers.append(nn.utils.weight_norm(layer))
|
110 |
+
self.dropouts.append(
|
111 |
+
FairseqDropout(dropout, module_name=self.__class__.__name__)
|
112 |
+
)
|
113 |
+
if out_channels % 2 != 0:
|
114 |
+
raise ValueError("odd # of out_channels is incompatible with GLU")
|
115 |
+
cur_channels = out_channels // 2 # halved by GLU
|
116 |
+
|
117 |
+
for out_channels in [2 * cur_channels, vocab_size]:
|
118 |
+
layer = nn.Linear(cur_channels, out_channels)
|
119 |
+
layer.weight.data.mul_(math.sqrt(3))
|
120 |
+
self.linear_layers.append(nn.utils.weight_norm(layer))
|
121 |
+
cur_channels = out_channels // 2
|
122 |
+
|
123 |
+
def forward(self, src_tokens, src_lengths, **kwargs):
|
124 |
+
|
125 |
+
"""
|
126 |
+
src_tokens: padded tensor (B, T, C * feat)
|
127 |
+
src_lengths: tensor of original lengths of input utterances (B,)
|
128 |
+
"""
|
129 |
+
B, T, _ = src_tokens.size()
|
130 |
+
x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1
|
131 |
+
|
132 |
+
for layer_idx in range(len(self.conv_layers)):
|
133 |
+
x = self.conv_layers[layer_idx](x)
|
134 |
+
x = F.glu(x, dim=1)
|
135 |
+
x = self.dropouts[layer_idx](x)
|
136 |
+
|
137 |
+
x = x.transpose(1, 2).contiguous() # (B, T, 908)
|
138 |
+
x = self.linear_layers[0](x)
|
139 |
+
x = F.glu(x, dim=2)
|
140 |
+
x = self.dropouts[-1](x)
|
141 |
+
x = self.linear_layers[1](x)
|
142 |
+
|
143 |
+
assert x.size(0) == B
|
144 |
+
assert x.size(1) == T
|
145 |
+
|
146 |
+
encoder_out = x.transpose(0, 1) # (T, B, vocab_size)
|
147 |
+
|
148 |
+
# need to debug this -- find a simpler/elegant way in pytorch APIs
|
149 |
+
encoder_padding_mask = (
|
150 |
+
torch.arange(T).view(1, T).expand(B, -1).to(x.device)
|
151 |
+
>= src_lengths.view(B, 1).expand(-1, T)
|
152 |
+
).t() # (B x T) -> (T x B)
|
153 |
+
|
154 |
+
return {
|
155 |
+
"encoder_out": encoder_out, # (T, B, vocab_size)
|
156 |
+
"encoder_padding_mask": encoder_padding_mask, # (T, B)
|
157 |
+
}
|
158 |
+
|
159 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
160 |
+
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
|
161 |
+
1, new_order
|
162 |
+
)
|
163 |
+
encoder_out["encoder_padding_mask"] = encoder_out[
|
164 |
+
"encoder_padding_mask"
|
165 |
+
].index_select(1, new_order)
|
166 |
+
return encoder_out
|
167 |
+
|
168 |
+
def max_positions(self):
|
169 |
+
"""Maximum input length supported by the encoder."""
|
170 |
+
return (1e6, 1e6) # an arbitrary large number
|
171 |
+
|
172 |
+
|
173 |
+
@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
|
174 |
+
def w2l_conv_glu_enc(args):
|
175 |
+
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
176 |
+
args.in_channels = getattr(args, "in_channels", 1)
|
177 |
+
args.conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
|
fairseq/examples/speech_recognition/new/README.md
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Flashlight Decoder
|
2 |
+
|
3 |
+
This script runs decoding for pre-trained speech recognition models.
|
4 |
+
|
5 |
+
## Usage
|
6 |
+
|
7 |
+
Assuming a few variables:
|
8 |
+
|
9 |
+
```bash
|
10 |
+
checkpoint=<path-to-checkpoint>
|
11 |
+
data=<path-to-data-directory>
|
12 |
+
lm_model=<path-to-language-model>
|
13 |
+
lexicon=<path-to-lexicon>
|
14 |
+
```
|
15 |
+
|
16 |
+
Example usage for decoding a fine-tuned Wav2Vec model:
|
17 |
+
|
18 |
+
```bash
|
19 |
+
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
|
20 |
+
task=audio_pretraining \
|
21 |
+
task.data=$data \
|
22 |
+
task.labels=ltr \
|
23 |
+
common_eval.path=$checkpoint \
|
24 |
+
decoding.type=kenlm \
|
25 |
+
decoding.lexicon=$lexicon \
|
26 |
+
decoding.lmpath=$lm_model \
|
27 |
+
dataset.gen_subset=dev_clean,dev_other,test_clean,test_other
|
28 |
+
```
|
29 |
+
|
30 |
+
Example usage for using Ax to sweep WER parameters (requires `pip install hydra-ax-sweeper`):
|
31 |
+
|
32 |
+
```bash
|
33 |
+
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
|
34 |
+
hydra/sweeper=ax \
|
35 |
+
task=audio_pretraining \
|
36 |
+
task.data=$data \
|
37 |
+
task.labels=ltr \
|
38 |
+
common_eval.path=$checkpoint \
|
39 |
+
decoding.type=kenlm \
|
40 |
+
decoding.lexicon=$lexicon \
|
41 |
+
decoding.lmpath=$lm_model \
|
42 |
+
dataset.gen_subset=dev_other
|
43 |
+
```
|
fairseq/examples/speech_recognition/new/__init__.py
ADDED
File without changes
|
fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package hydra.sweeper
|
2 |
+
_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
|
3 |
+
max_batch_size: null
|
4 |
+
ax_config:
|
5 |
+
max_trials: 128
|
6 |
+
early_stop:
|
7 |
+
minimize: true
|
8 |
+
max_epochs_without_improvement: 10
|
9 |
+
epsilon: 0.025
|
10 |
+
experiment:
|
11 |
+
name: ${dataset.gen_subset}
|
12 |
+
objective_name: wer
|
13 |
+
minimize: true
|
14 |
+
parameter_constraints: null
|
15 |
+
outcome_constraints: null
|
16 |
+
status_quo: null
|
17 |
+
client:
|
18 |
+
verbose_logging: false
|
19 |
+
random_seed: null
|
20 |
+
params:
|
21 |
+
decoding.lmweight:
|
22 |
+
type: range
|
23 |
+
bounds: [0.0, 5.0]
|
24 |
+
decoding.wordscore:
|
25 |
+
type: range
|
26 |
+
bounds: [-5.0, 5.0]
|
27 |
+
decoding.silweight:
|
28 |
+
type: range
|
29 |
+
bounds: [ -8.0, 0.0 ]
|
fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package hydra.sweeper
|
2 |
+
_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
|
3 |
+
max_batch_size: null
|
4 |
+
ax_config:
|
5 |
+
max_trials: 64
|
6 |
+
early_stop:
|
7 |
+
minimize: true
|
8 |
+
max_epochs_without_improvement: 10
|
9 |
+
epsilon: 0.025
|
10 |
+
experiment:
|
11 |
+
name: ${dataset.gen_subset}
|
12 |
+
objective_name: wer
|
13 |
+
minimize: true
|
14 |
+
parameter_constraints: null
|
15 |
+
outcome_constraints: null
|
16 |
+
status_quo: null
|
17 |
+
client:
|
18 |
+
verbose_logging: false
|
19 |
+
random_seed: null
|
20 |
+
params:
|
21 |
+
decoding.lmweight:
|
22 |
+
type: range
|
23 |
+
bounds: [0.0, 10.0]
|
24 |
+
decoding.wordscore:
|
25 |
+
type: range
|
26 |
+
bounds: [-10.0, 10.0]
|
27 |
+
decoding.silweight:
|
28 |
+
type: range
|
29 |
+
bounds: [ -10.0, 0.0 ]
|
fairseq/examples/speech_recognition/new/conf/infer.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- task: null
|
5 |
+
- model: null
|
6 |
+
|
7 |
+
hydra:
|
8 |
+
run:
|
9 |
+
dir: ${common_eval.results_path}/${dataset.gen_subset}
|
10 |
+
sweep:
|
11 |
+
dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
|
12 |
+
subdir: ${dataset.gen_subset}
|
13 |
+
common:
|
14 |
+
user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
|
15 |
+
common_eval:
|
16 |
+
results_path: null
|
17 |
+
path: null
|
18 |
+
post_process: letter
|
19 |
+
quiet: true
|
20 |
+
dataset:
|
21 |
+
max_tokens: 3000000
|
22 |
+
gen_subset: test
|
23 |
+
distributed_training:
|
24 |
+
distributed_world_size: 1
|
25 |
+
decoding:
|
26 |
+
beam: 5
|
27 |
+
type: viterbi
|
fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
hydra:
|
4 |
+
job:
|
5 |
+
config:
|
6 |
+
override_dirname:
|
7 |
+
kv_sep: ':'
|
8 |
+
item_sep: '/'
|
9 |
+
exclude_keys:
|
10 |
+
- run_config
|
11 |
+
- distributed_training.distributed_port
|
12 |
+
- common_eval.path
|
13 |
+
sweep:
|
14 |
+
dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
15 |
+
# subdir: ${hydra.job.override_dirname}
|
16 |
+
launcher:
|
17 |
+
cpus_per_task: 16
|
18 |
+
gpus_per_node: 1
|
19 |
+
tasks_per_node: 1
|
20 |
+
nodes: 1
|
21 |
+
partition: devlab,learnlab
|
22 |
+
mem_gb: 100
|
23 |
+
timeout_min: 2000
|
24 |
+
max_num_timeout: 10
|
25 |
+
name: ${env:PREFIX}_${hydra.job.config_name}
|
26 |
+
submitit_folder: ${hydra.sweep.dir}/%j
|
27 |
+
constraint: volta32gb
|
28 |
+
exclude: learnfair7598
|
fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
hydra:
|
4 |
+
job:
|
5 |
+
config:
|
6 |
+
override_dirname:
|
7 |
+
kv_sep: ':'
|
8 |
+
item_sep: '/'
|
9 |
+
exclude_keys:
|
10 |
+
- run_config
|
11 |
+
- distributed_training.distributed_port
|
12 |
+
- common_eval.path
|
13 |
+
sweep:
|
14 |
+
dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
15 |
+
# subdir: ${hydra.job.override_dirname}
|
16 |
+
launcher:
|
17 |
+
cpus_per_task: 16
|
18 |
+
gpus_per_node: 2
|
19 |
+
tasks_per_node: 2
|
20 |
+
nodes: 1
|
21 |
+
partition: devlab,learnlab
|
22 |
+
mem_gb: 100
|
23 |
+
timeout_min: 2000
|
24 |
+
max_num_timeout: 10
|
25 |
+
name: ${env:PREFIX}_${hydra.job.config_name}
|
26 |
+
submitit_folder: ${hydra.sweep.dir}/%j
|
27 |
+
constraint: volta32gb
|
fairseq/examples/speech_recognition/new/decoders/__init__.py
ADDED
File without changes
|
fairseq/examples/speech_recognition/new/decoders/base_decoder.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import itertools as it
|
7 |
+
from typing import Any, Dict, List
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from fairseq.data.dictionary import Dictionary
|
11 |
+
from fairseq.models.fairseq_model import FairseqModel
|
12 |
+
|
13 |
+
|
14 |
+
class BaseDecoder:
|
15 |
+
def __init__(self, tgt_dict: Dictionary) -> None:
|
16 |
+
self.tgt_dict = tgt_dict
|
17 |
+
self.vocab_size = len(tgt_dict)
|
18 |
+
|
19 |
+
self.blank = (
|
20 |
+
tgt_dict.index("<ctc_blank>")
|
21 |
+
if "<ctc_blank>" in tgt_dict.indices
|
22 |
+
else tgt_dict.bos()
|
23 |
+
)
|
24 |
+
if "<sep>" in tgt_dict.indices:
|
25 |
+
self.silence = tgt_dict.index("<sep>")
|
26 |
+
elif "|" in tgt_dict.indices:
|
27 |
+
self.silence = tgt_dict.index("|")
|
28 |
+
else:
|
29 |
+
self.silence = tgt_dict.eos()
|
30 |
+
|
31 |
+
def generate(
|
32 |
+
self, models: List[FairseqModel], sample: Dict[str, Any], **unused
|
33 |
+
) -> List[List[Dict[str, torch.LongTensor]]]:
|
34 |
+
encoder_input = {
|
35 |
+
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
|
36 |
+
}
|
37 |
+
emissions = self.get_emissions(models, encoder_input)
|
38 |
+
return self.decode(emissions)
|
39 |
+
|
40 |
+
def get_emissions(
|
41 |
+
self,
|
42 |
+
models: List[FairseqModel],
|
43 |
+
encoder_input: Dict[str, Any],
|
44 |
+
) -> torch.FloatTensor:
|
45 |
+
model = models[0]
|
46 |
+
encoder_out = model(**encoder_input)
|
47 |
+
if hasattr(model, "get_logits"):
|
48 |
+
emissions = model.get_logits(encoder_out)
|
49 |
+
else:
|
50 |
+
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
|
51 |
+
return emissions.transpose(0, 1).float().cpu().contiguous()
|
52 |
+
|
53 |
+
def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
|
54 |
+
idxs = (g[0] for g in it.groupby(idxs))
|
55 |
+
idxs = filter(lambda x: x != self.blank, idxs)
|
56 |
+
return torch.LongTensor(list(idxs))
|
57 |
+
|
58 |
+
def decode(
|
59 |
+
self,
|
60 |
+
emissions: torch.FloatTensor,
|
61 |
+
) -> List[List[Dict[str, torch.LongTensor]]]:
|
62 |
+
raise NotImplementedError
|
fairseq/examples/speech_recognition/new/decoders/decoder.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
from fairseq.data.dictionary import Dictionary
|
11 |
+
|
12 |
+
from .decoder_config import DecoderConfig, FlashlightDecoderConfig
|
13 |
+
from .base_decoder import BaseDecoder
|
14 |
+
|
15 |
+
|
16 |
+
def Decoder(
|
17 |
+
cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary
|
18 |
+
) -> BaseDecoder:
|
19 |
+
|
20 |
+
if cfg.type == "viterbi":
|
21 |
+
from .viterbi_decoder import ViterbiDecoder
|
22 |
+
|
23 |
+
return ViterbiDecoder(tgt_dict)
|
24 |
+
if cfg.type == "kenlm":
|
25 |
+
from .flashlight_decoder import KenLMDecoder
|
26 |
+
|
27 |
+
return KenLMDecoder(cfg, tgt_dict)
|
28 |
+
if cfg.type == "fairseqlm":
|
29 |
+
from .flashlight_decoder import FairseqLMDecoder
|
30 |
+
|
31 |
+
return FairseqLMDecoder(cfg, tgt_dict)
|
32 |
+
raise NotImplementedError(f"Invalid decoder name: {cfg.name}")
|
fairseq/examples/speech_recognition/new/decoders/decoder_config.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
from fairseq.dataclass.configs import FairseqDataclass
|
11 |
+
from fairseq.dataclass.constants import ChoiceEnum
|
12 |
+
from omegaconf import MISSING
|
13 |
+
|
14 |
+
|
15 |
+
DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"])
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class DecoderConfig(FairseqDataclass):
|
20 |
+
type: DECODER_CHOICES = field(
|
21 |
+
default="viterbi",
|
22 |
+
metadata={"help": "The type of decoder to use"},
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class FlashlightDecoderConfig(FairseqDataclass):
|
28 |
+
nbest: int = field(
|
29 |
+
default=1,
|
30 |
+
metadata={"help": "Number of decodings to return"},
|
31 |
+
)
|
32 |
+
unitlm: bool = field(
|
33 |
+
default=False,
|
34 |
+
metadata={"help": "If set, use unit language model"},
|
35 |
+
)
|
36 |
+
lmpath: str = field(
|
37 |
+
default=MISSING,
|
38 |
+
metadata={"help": "Language model for KenLM decoder"},
|
39 |
+
)
|
40 |
+
lexicon: Optional[str] = field(
|
41 |
+
default=None,
|
42 |
+
metadata={"help": "Lexicon for Flashlight decoder"},
|
43 |
+
)
|
44 |
+
beam: int = field(
|
45 |
+
default=50,
|
46 |
+
metadata={"help": "Number of beams to use for decoding"},
|
47 |
+
)
|
48 |
+
beamthreshold: float = field(
|
49 |
+
default=50.0,
|
50 |
+
metadata={"help": "Threshold for beam search decoding"},
|
51 |
+
)
|
52 |
+
beamsizetoken: Optional[int] = field(
|
53 |
+
default=None, metadata={"help": "Beam size to use"}
|
54 |
+
)
|
55 |
+
wordscore: float = field(
|
56 |
+
default=-1,
|
57 |
+
metadata={"help": "Word score for KenLM decoder"},
|
58 |
+
)
|
59 |
+
unkweight: float = field(
|
60 |
+
default=-math.inf,
|
61 |
+
metadata={"help": "Unknown weight for KenLM decoder"},
|
62 |
+
)
|
63 |
+
silweight: float = field(
|
64 |
+
default=0,
|
65 |
+
metadata={"help": "Silence weight for KenLM decoder"},
|
66 |
+
)
|
67 |
+
lmweight: float = field(
|
68 |
+
default=2,
|
69 |
+
metadata={"help": "Weight for LM while interpolating score"},
|
70 |
+
)
|
fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
import gc
|
9 |
+
import os.path as osp
|
10 |
+
import warnings
|
11 |
+
from collections import deque, namedtuple
|
12 |
+
from typing import Any, Dict, Tuple
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
from fairseq import tasks
|
17 |
+
from fairseq.data.dictionary import Dictionary
|
18 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
19 |
+
from fairseq.models.fairseq_model import FairseqModel
|
20 |
+
from fairseq.utils import apply_to_sample
|
21 |
+
from omegaconf import open_dict, OmegaConf
|
22 |
+
|
23 |
+
from typing import List
|
24 |
+
|
25 |
+
from .decoder_config import FlashlightDecoderConfig
|
26 |
+
from .base_decoder import BaseDecoder
|
27 |
+
|
28 |
+
try:
|
29 |
+
from flashlight.lib.text.decoder import (
|
30 |
+
LM,
|
31 |
+
CriterionType,
|
32 |
+
DecodeResult,
|
33 |
+
KenLM,
|
34 |
+
LexiconDecoder,
|
35 |
+
LexiconDecoderOptions,
|
36 |
+
LexiconFreeDecoder,
|
37 |
+
LexiconFreeDecoderOptions,
|
38 |
+
LMState,
|
39 |
+
SmearingMode,
|
40 |
+
Trie,
|
41 |
+
)
|
42 |
+
from flashlight.lib.text.dictionary import create_word_dict, load_words
|
43 |
+
from flashlight.lib.text.dictionary import Dictionary as flDictionary
|
44 |
+
except ImportError:
|
45 |
+
warnings.warn(
|
46 |
+
"flashlight python bindings are required to use this functionality. "
|
47 |
+
"Please install from "
|
48 |
+
"https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
|
49 |
+
)
|
50 |
+
LM = object
|
51 |
+
LMState = object
|
52 |
+
|
53 |
+
|
54 |
+
class KenLMDecoder(BaseDecoder):
|
55 |
+
def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
|
56 |
+
super().__init__(tgt_dict)
|
57 |
+
|
58 |
+
self.nbest = cfg.nbest
|
59 |
+
self.unitlm = cfg.unitlm
|
60 |
+
|
61 |
+
if cfg.lexicon:
|
62 |
+
self.lexicon = load_words(cfg.lexicon)
|
63 |
+
self.word_dict = create_word_dict(self.lexicon)
|
64 |
+
self.unk_word = self.word_dict.get_index("<unk>")
|
65 |
+
|
66 |
+
self.lm = KenLM(cfg.lmpath, self.word_dict)
|
67 |
+
self.trie = Trie(self.vocab_size, self.silence)
|
68 |
+
|
69 |
+
start_state = self.lm.start(False)
|
70 |
+
for word, spellings in self.lexicon.items():
|
71 |
+
word_idx = self.word_dict.get_index(word)
|
72 |
+
_, score = self.lm.score(start_state, word_idx)
|
73 |
+
for spelling in spellings:
|
74 |
+
spelling_idxs = [tgt_dict.index(token) for token in spelling]
|
75 |
+
assert (
|
76 |
+
tgt_dict.unk() not in spelling_idxs
|
77 |
+
), f"{word} {spelling} {spelling_idxs}"
|
78 |
+
self.trie.insert(spelling_idxs, word_idx, score)
|
79 |
+
self.trie.smear(SmearingMode.MAX)
|
80 |
+
|
81 |
+
self.decoder_opts = LexiconDecoderOptions(
|
82 |
+
beam_size=cfg.beam,
|
83 |
+
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
|
84 |
+
beam_threshold=cfg.beamthreshold,
|
85 |
+
lm_weight=cfg.lmweight,
|
86 |
+
word_score=cfg.wordscore,
|
87 |
+
unk_score=cfg.unkweight,
|
88 |
+
sil_score=cfg.silweight,
|
89 |
+
log_add=False,
|
90 |
+
criterion_type=CriterionType.CTC,
|
91 |
+
)
|
92 |
+
|
93 |
+
self.decoder = LexiconDecoder(
|
94 |
+
self.decoder_opts,
|
95 |
+
self.trie,
|
96 |
+
self.lm,
|
97 |
+
self.silence,
|
98 |
+
self.blank,
|
99 |
+
self.unk_word,
|
100 |
+
[],
|
101 |
+
self.unitlm,
|
102 |
+
)
|
103 |
+
else:
|
104 |
+
assert self.unitlm, "Lexicon-free decoding requires unit LM"
|
105 |
+
|
106 |
+
self.word_dict = flDictionary()
|
107 |
+
for sym in tgt_dict.symbols:
|
108 |
+
self.word_dict.add_entry(sym, tgt_dict.index(sym))
|
109 |
+
self.lm = KenLM(cfg.lmpath, self.word_dict)
|
110 |
+
self.decoder_opts = LexiconFreeDecoderOptions(
|
111 |
+
beam_size=cfg.beam,
|
112 |
+
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
|
113 |
+
beam_threshold=cfg.beamthreshold,
|
114 |
+
lm_weight=cfg.lmweight,
|
115 |
+
sil_score=cfg.silweight,
|
116 |
+
log_add=False,
|
117 |
+
criterion_type=CriterionType.CTC,
|
118 |
+
)
|
119 |
+
self.decoder = LexiconFreeDecoder(
|
120 |
+
self.decoder_opts, self.lm, self.silence, self.blank, []
|
121 |
+
)
|
122 |
+
|
123 |
+
def get_timesteps(self, token_idxs: List[int]) -> List[int]:
|
124 |
+
"""Returns frame numbers corresponding to every non-blank token.
|
125 |
+
|
126 |
+
Parameters
|
127 |
+
----------
|
128 |
+
token_idxs : List[int]
|
129 |
+
IDs of decoded tokens.
|
130 |
+
|
131 |
+
Returns
|
132 |
+
-------
|
133 |
+
List[int]
|
134 |
+
Frame numbers corresponding to every non-blank token.
|
135 |
+
"""
|
136 |
+
timesteps = []
|
137 |
+
for i, token_idx in enumerate(token_idxs):
|
138 |
+
if token_idx == self.blank:
|
139 |
+
continue
|
140 |
+
if i == 0 or token_idx != token_idxs[i-1]:
|
141 |
+
timesteps.append(i)
|
142 |
+
return timesteps
|
143 |
+
|
144 |
+
def decode(
|
145 |
+
self,
|
146 |
+
emissions: torch.FloatTensor,
|
147 |
+
) -> List[List[Dict[str, torch.LongTensor]]]:
|
148 |
+
B, T, N = emissions.size()
|
149 |
+
hypos = []
|
150 |
+
for b in range(B):
|
151 |
+
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
|
152 |
+
results = self.decoder.decode(emissions_ptr, T, N)
|
153 |
+
|
154 |
+
nbest_results = results[: self.nbest]
|
155 |
+
hypos.append(
|
156 |
+
[
|
157 |
+
{
|
158 |
+
"tokens": self.get_tokens(result.tokens),
|
159 |
+
"score": result.score,
|
160 |
+
"timesteps": self.get_timesteps(result.tokens),
|
161 |
+
"words": [
|
162 |
+
self.word_dict.get_entry(x) for x in result.words if x >= 0
|
163 |
+
],
|
164 |
+
}
|
165 |
+
for result in nbest_results
|
166 |
+
]
|
167 |
+
)
|
168 |
+
return hypos
|
169 |
+
|
170 |
+
|
171 |
+
FairseqLMState = namedtuple(
|
172 |
+
"FairseqLMState",
|
173 |
+
[
|
174 |
+
"prefix",
|
175 |
+
"incremental_state",
|
176 |
+
"probs",
|
177 |
+
],
|
178 |
+
)
|
179 |
+
|
180 |
+
|
181 |
+
class FairseqLM(LM):
|
182 |
+
def __init__(self, dictionary: Dictionary, model: FairseqModel) -> None:
|
183 |
+
super().__init__()
|
184 |
+
|
185 |
+
self.dictionary = dictionary
|
186 |
+
self.model = model
|
187 |
+
self.unk = self.dictionary.unk()
|
188 |
+
|
189 |
+
self.save_incremental = False # this currently does not work properly
|
190 |
+
self.max_cache = 20_000
|
191 |
+
|
192 |
+
if torch.cuda.is_available():
|
193 |
+
model.cuda()
|
194 |
+
model.eval()
|
195 |
+
model.make_generation_fast_()
|
196 |
+
|
197 |
+
self.states = {}
|
198 |
+
self.stateq = deque()
|
199 |
+
|
200 |
+
def start(self, start_with_nothing: bool) -> LMState:
|
201 |
+
state = LMState()
|
202 |
+
prefix = torch.LongTensor([[self.dictionary.eos()]])
|
203 |
+
incremental_state = {} if self.save_incremental else None
|
204 |
+
with torch.no_grad():
|
205 |
+
res = self.model(prefix.cuda(), incremental_state=incremental_state)
|
206 |
+
probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
|
207 |
+
|
208 |
+
if incremental_state is not None:
|
209 |
+
incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
|
210 |
+
self.states[state] = FairseqLMState(
|
211 |
+
prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
|
212 |
+
)
|
213 |
+
self.stateq.append(state)
|
214 |
+
|
215 |
+
return state
|
216 |
+
|
217 |
+
def score(
|
218 |
+
self,
|
219 |
+
state: LMState,
|
220 |
+
token_index: int,
|
221 |
+
no_cache: bool = False,
|
222 |
+
) -> Tuple[LMState, int]:
|
223 |
+
"""
|
224 |
+
Evaluate language model based on the current lm state and new word
|
225 |
+
Parameters:
|
226 |
+
-----------
|
227 |
+
state: current lm state
|
228 |
+
token_index: index of the word
|
229 |
+
(can be lexicon index then you should store inside LM the
|
230 |
+
mapping between indices of lexicon and lm, or lm index of a word)
|
231 |
+
Returns:
|
232 |
+
--------
|
233 |
+
(LMState, float): pair of (new state, score for the current word)
|
234 |
+
"""
|
235 |
+
curr_state = self.states[state]
|
236 |
+
|
237 |
+
def trim_cache(targ_size: int) -> None:
|
238 |
+
while len(self.stateq) > targ_size:
|
239 |
+
rem_k = self.stateq.popleft()
|
240 |
+
rem_st = self.states[rem_k]
|
241 |
+
rem_st = FairseqLMState(rem_st.prefix, None, None)
|
242 |
+
self.states[rem_k] = rem_st
|
243 |
+
|
244 |
+
if curr_state.probs is None:
|
245 |
+
new_incremental_state = (
|
246 |
+
curr_state.incremental_state.copy()
|
247 |
+
if curr_state.incremental_state is not None
|
248 |
+
else None
|
249 |
+
)
|
250 |
+
with torch.no_grad():
|
251 |
+
if new_incremental_state is not None:
|
252 |
+
new_incremental_state = apply_to_sample(
|
253 |
+
lambda x: x.cuda(), new_incremental_state
|
254 |
+
)
|
255 |
+
elif self.save_incremental:
|
256 |
+
new_incremental_state = {}
|
257 |
+
|
258 |
+
res = self.model(
|
259 |
+
torch.from_numpy(curr_state.prefix).cuda(),
|
260 |
+
incremental_state=new_incremental_state,
|
261 |
+
)
|
262 |
+
probs = self.model.get_normalized_probs(
|
263 |
+
res, log_probs=True, sample=None
|
264 |
+
)
|
265 |
+
|
266 |
+
if new_incremental_state is not None:
|
267 |
+
new_incremental_state = apply_to_sample(
|
268 |
+
lambda x: x.cpu(), new_incremental_state
|
269 |
+
)
|
270 |
+
|
271 |
+
curr_state = FairseqLMState(
|
272 |
+
curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
|
273 |
+
)
|
274 |
+
|
275 |
+
if not no_cache:
|
276 |
+
self.states[state] = curr_state
|
277 |
+
self.stateq.append(state)
|
278 |
+
|
279 |
+
score = curr_state.probs[token_index].item()
|
280 |
+
|
281 |
+
trim_cache(self.max_cache)
|
282 |
+
|
283 |
+
outstate = state.child(token_index)
|
284 |
+
if outstate not in self.states and not no_cache:
|
285 |
+
prefix = np.concatenate(
|
286 |
+
[curr_state.prefix, torch.LongTensor([[token_index]])], -1
|
287 |
+
)
|
288 |
+
incr_state = curr_state.incremental_state
|
289 |
+
|
290 |
+
self.states[outstate] = FairseqLMState(prefix, incr_state, None)
|
291 |
+
|
292 |
+
if token_index == self.unk:
|
293 |
+
score = float("-inf")
|
294 |
+
|
295 |
+
return outstate, score
|
296 |
+
|
297 |
+
def finish(self, state: LMState) -> Tuple[LMState, int]:
|
298 |
+
"""
|
299 |
+
Evaluate eos for language model based on the current lm state
|
300 |
+
Returns:
|
301 |
+
--------
|
302 |
+
(LMState, float): pair of (new state, score for the current word)
|
303 |
+
"""
|
304 |
+
return self.score(state, self.dictionary.eos())
|
305 |
+
|
306 |
+
def empty_cache(self) -> None:
|
307 |
+
self.states = {}
|
308 |
+
self.stateq = deque()
|
309 |
+
gc.collect()
|
310 |
+
|
311 |
+
|
312 |
+
class FairseqLMDecoder(BaseDecoder):
|
313 |
+
def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
|
314 |
+
super().__init__(tgt_dict)
|
315 |
+
|
316 |
+
self.nbest = cfg.nbest
|
317 |
+
self.unitlm = cfg.unitlm
|
318 |
+
|
319 |
+
self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
|
320 |
+
self.idx_to_wrd = {}
|
321 |
+
|
322 |
+
checkpoint = torch.load(cfg.lmpath, map_location="cpu")
|
323 |
+
|
324 |
+
if "cfg" in checkpoint and checkpoint["cfg"] is not None:
|
325 |
+
lm_args = checkpoint["cfg"]
|
326 |
+
else:
|
327 |
+
lm_args = convert_namespace_to_omegaconf(checkpoint["args"])
|
328 |
+
|
329 |
+
if not OmegaConf.is_dict(lm_args):
|
330 |
+
lm_args = OmegaConf.create(lm_args)
|
331 |
+
|
332 |
+
with open_dict(lm_args.task):
|
333 |
+
lm_args.task.data = osp.dirname(cfg.lmpath)
|
334 |
+
|
335 |
+
task = tasks.setup_task(lm_args.task)
|
336 |
+
model = task.build_model(lm_args.model)
|
337 |
+
model.load_state_dict(checkpoint["model"], strict=False)
|
338 |
+
|
339 |
+
self.trie = Trie(self.vocab_size, self.silence)
|
340 |
+
|
341 |
+
self.word_dict = task.dictionary
|
342 |
+
self.unk_word = self.word_dict.unk()
|
343 |
+
self.lm = FairseqLM(self.word_dict, model)
|
344 |
+
|
345 |
+
if self.lexicon:
|
346 |
+
start_state = self.lm.start(False)
|
347 |
+
for i, (word, spellings) in enumerate(self.lexicon.items()):
|
348 |
+
if self.unitlm:
|
349 |
+
word_idx = i
|
350 |
+
self.idx_to_wrd[i] = word
|
351 |
+
score = 0
|
352 |
+
else:
|
353 |
+
word_idx = self.word_dict.index(word)
|
354 |
+
_, score = self.lm.score(start_state, word_idx, no_cache=True)
|
355 |
+
|
356 |
+
for spelling in spellings:
|
357 |
+
spelling_idxs = [tgt_dict.index(token) for token in spelling]
|
358 |
+
assert (
|
359 |
+
tgt_dict.unk() not in spelling_idxs
|
360 |
+
), f"{spelling} {spelling_idxs}"
|
361 |
+
self.trie.insert(spelling_idxs, word_idx, score)
|
362 |
+
self.trie.smear(SmearingMode.MAX)
|
363 |
+
|
364 |
+
self.decoder_opts = LexiconDecoderOptions(
|
365 |
+
beam_size=cfg.beam,
|
366 |
+
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
|
367 |
+
beam_threshold=cfg.beamthreshold,
|
368 |
+
lm_weight=cfg.lmweight,
|
369 |
+
word_score=cfg.wordscore,
|
370 |
+
unk_score=cfg.unkweight,
|
371 |
+
sil_score=cfg.silweight,
|
372 |
+
log_add=False,
|
373 |
+
criterion_type=CriterionType.CTC,
|
374 |
+
)
|
375 |
+
|
376 |
+
self.decoder = LexiconDecoder(
|
377 |
+
self.decoder_opts,
|
378 |
+
self.trie,
|
379 |
+
self.lm,
|
380 |
+
self.silence,
|
381 |
+
self.blank,
|
382 |
+
self.unk_word,
|
383 |
+
[],
|
384 |
+
self.unitlm,
|
385 |
+
)
|
386 |
+
else:
|
387 |
+
assert self.unitlm, "Lexicon-free decoding requires unit LM"
|
388 |
+
|
389 |
+
d = {w: [[w]] for w in tgt_dict.symbols}
|
390 |
+
self.word_dict = create_word_dict(d)
|
391 |
+
self.lm = KenLM(cfg.lmpath, self.word_dict)
|
392 |
+
self.decoder_opts = LexiconFreeDecoderOptions(
|
393 |
+
beam_size=cfg.beam,
|
394 |
+
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
|
395 |
+
beam_threshold=cfg.beamthreshold,
|
396 |
+
lm_weight=cfg.lmweight,
|
397 |
+
sil_score=cfg.silweight,
|
398 |
+
log_add=False,
|
399 |
+
criterion_type=CriterionType.CTC,
|
400 |
+
)
|
401 |
+
self.decoder = LexiconFreeDecoder(
|
402 |
+
self.decoder_opts, self.lm, self.silence, self.blank, []
|
403 |
+
)
|
404 |
+
|
405 |
+
def decode(
|
406 |
+
self,
|
407 |
+
emissions: torch.FloatTensor,
|
408 |
+
) -> List[List[Dict[str, torch.LongTensor]]]:
|
409 |
+
B, T, N = emissions.size()
|
410 |
+
hypos = []
|
411 |
+
|
412 |
+
def make_hypo(result: DecodeResult) -> Dict[str, Any]:
|
413 |
+
hypo = {
|
414 |
+
"tokens": self.get_tokens(result.tokens),
|
415 |
+
"score": result.score,
|
416 |
+
}
|
417 |
+
if self.lexicon:
|
418 |
+
hypo["words"] = [
|
419 |
+
self.idx_to_wrd[x] if self.unitlm else self.word_dict[x]
|
420 |
+
for x in result.words
|
421 |
+
if x >= 0
|
422 |
+
]
|
423 |
+
return hypo
|
424 |
+
|
425 |
+
for b in range(B):
|
426 |
+
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
|
427 |
+
results = self.decoder.decode(emissions_ptr, T, N)
|
428 |
+
|
429 |
+
nbest_results = results[: self.nbest]
|
430 |
+
hypos.append([make_hypo(result) for result in nbest_results])
|
431 |
+
self.lm.empty_cache()
|
432 |
+
|
433 |
+
return hypos
|
fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from typing import List, Dict
|
11 |
+
|
12 |
+
from .base_decoder import BaseDecoder
|
13 |
+
|
14 |
+
|
15 |
+
class ViterbiDecoder(BaseDecoder):
|
16 |
+
def decode(
|
17 |
+
self,
|
18 |
+
emissions: torch.FloatTensor,
|
19 |
+
) -> List[List[Dict[str, torch.LongTensor]]]:
|
20 |
+
def get_pred(e):
|
21 |
+
score = e.log_softmax(dim=-1).max(dim=-1)[0].sum()
|
22 |
+
toks = e.argmax(dim=-1).unique_consecutive()
|
23 |
+
return {"tokens":toks[toks != self.blank], "score":score}
|
24 |
+
return [[get_pred(x)] for x in emissions]
|
fairseq/examples/speech_recognition/new/infer.py
ADDED
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python -u
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import ast
|
8 |
+
import hashlib
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import shutil
|
12 |
+
import sys
|
13 |
+
import re
|
14 |
+
from dataclasses import dataclass, field, is_dataclass
|
15 |
+
from pathlib import Path
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import editdistance
|
19 |
+
import torch
|
20 |
+
import torch.distributed as dist
|
21 |
+
from examples.speech_recognition.new.decoders.decoder_config import (
|
22 |
+
DecoderConfig,
|
23 |
+
FlashlightDecoderConfig,
|
24 |
+
)
|
25 |
+
from examples.speech_recognition.new.decoders.decoder import Decoder
|
26 |
+
from fairseq import checkpoint_utils, distributed_utils, progress_bar, tasks, utils
|
27 |
+
from fairseq.data.data_utils import post_process
|
28 |
+
from fairseq.dataclass.configs import (
|
29 |
+
CheckpointConfig,
|
30 |
+
CommonConfig,
|
31 |
+
CommonEvalConfig,
|
32 |
+
DatasetConfig,
|
33 |
+
DistributedTrainingConfig,
|
34 |
+
FairseqDataclass,
|
35 |
+
)
|
36 |
+
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
37 |
+
from fairseq.logging.progress_bar import BaseProgressBar
|
38 |
+
from fairseq.models.fairseq_model import FairseqModel
|
39 |
+
from omegaconf import OmegaConf
|
40 |
+
|
41 |
+
import hydra
|
42 |
+
from hydra.core.config_store import ConfigStore
|
43 |
+
|
44 |
+
logging.root.setLevel(logging.INFO)
|
45 |
+
logging.basicConfig(level=logging.INFO)
|
46 |
+
logger = logging.getLogger(__name__)
|
47 |
+
|
48 |
+
config_path = Path(__file__).resolve().parent / "conf"
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class DecodingConfig(DecoderConfig, FlashlightDecoderConfig):
|
53 |
+
unique_wer_file: bool = field(
|
54 |
+
default=False,
|
55 |
+
metadata={"help": "If set, use a unique file for storing WER"},
|
56 |
+
)
|
57 |
+
results_path: Optional[str] = field(
|
58 |
+
default=None,
|
59 |
+
metadata={
|
60 |
+
"help": "If set, write hypothesis and reference sentences into this directory"
|
61 |
+
},
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class InferConfig(FairseqDataclass):
|
67 |
+
task: Any = None
|
68 |
+
decoding: DecodingConfig = DecodingConfig()
|
69 |
+
common: CommonConfig = CommonConfig()
|
70 |
+
common_eval: CommonEvalConfig = CommonEvalConfig()
|
71 |
+
checkpoint: CheckpointConfig = CheckpointConfig()
|
72 |
+
distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
|
73 |
+
dataset: DatasetConfig = DatasetConfig()
|
74 |
+
is_ax: bool = field(
|
75 |
+
default=False,
|
76 |
+
metadata={
|
77 |
+
"help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
|
78 |
+
},
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
def reset_logging():
|
83 |
+
root = logging.getLogger()
|
84 |
+
for handler in root.handlers:
|
85 |
+
root.removeHandler(handler)
|
86 |
+
root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
|
87 |
+
handler = logging.StreamHandler(sys.stdout)
|
88 |
+
handler.setFormatter(
|
89 |
+
logging.Formatter(
|
90 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
91 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
92 |
+
)
|
93 |
+
)
|
94 |
+
root.addHandler(handler)
|
95 |
+
|
96 |
+
|
97 |
+
class InferenceProcessor:
|
98 |
+
cfg: InferConfig
|
99 |
+
|
100 |
+
def __init__(self, cfg: InferConfig) -> None:
|
101 |
+
self.cfg = cfg
|
102 |
+
self.task = tasks.setup_task(cfg.task)
|
103 |
+
|
104 |
+
models, saved_cfg = self.load_model_ensemble()
|
105 |
+
|
106 |
+
### LOAD ADAPTER ####
|
107 |
+
ckpt_obj = checkpoint_utils.load_checkpoint_to_cpu(self.cfg.common_eval.path)
|
108 |
+
if "adapter" in ckpt_obj:
|
109 |
+
target_lang = self.cfg.dataset.gen_subset.split(":")[0]
|
110 |
+
assert target_lang in ckpt_obj["adapter"]
|
111 |
+
|
112 |
+
logger.info(f">>> LOADING ADAPTER: {target_lang}")
|
113 |
+
ft_obj = ckpt_obj["adapter"][target_lang]
|
114 |
+
ft_model = ft_obj["model"]
|
115 |
+
cdevice = models[0].w2v_encoder.proj.weight.device
|
116 |
+
cdtype = models[0].w2v_encoder.proj.weight.dtype
|
117 |
+
ft_proj_out, ft_proj_in = ft_model["w2v_encoder.proj.weight"].shape
|
118 |
+
ft_proj = torch.nn.Linear(ft_proj_in, ft_proj_out, bias=True)
|
119 |
+
ft_proj.to(device=cdevice, dtype=cdtype)
|
120 |
+
models[0].w2v_encoder.proj = ft_proj
|
121 |
+
with torch.no_grad():
|
122 |
+
for kk, vv in models[0].named_parameters():
|
123 |
+
if kk in ft_model:
|
124 |
+
vv.copy_(ft_model[kk])
|
125 |
+
self.task.load_state_dict(ft_obj["task_state"])
|
126 |
+
# overwrite gen_subset with master config
|
127 |
+
self.cfg.dataset.gen_subset = re.sub('^[\w-]+:', saved_cfg['task']['multi_corpus_keys']+":", self.cfg.dataset.gen_subset)
|
128 |
+
self.models = models
|
129 |
+
self.saved_cfg = saved_cfg
|
130 |
+
self.tgt_dict = self.task.target_dictionary
|
131 |
+
|
132 |
+
self.task.load_dataset(
|
133 |
+
self.cfg.dataset.gen_subset,
|
134 |
+
task_cfg=saved_cfg.task,
|
135 |
+
)
|
136 |
+
self.generator = Decoder(cfg.decoding, self.tgt_dict)
|
137 |
+
self.gen_timer = StopwatchMeter()
|
138 |
+
self.wps_meter = TimeMeter()
|
139 |
+
self.num_sentences = 0
|
140 |
+
self.total_errors = 0
|
141 |
+
self.total_length = 0
|
142 |
+
|
143 |
+
self.hypo_words_file = None
|
144 |
+
self.hypo_units_file = None
|
145 |
+
self.ref_words_file = None
|
146 |
+
self.ref_units_file = None
|
147 |
+
self.score_file = None
|
148 |
+
|
149 |
+
self.progress_bar = self.build_progress_bar()
|
150 |
+
|
151 |
+
def __enter__(self) -> "InferenceProcessor":
|
152 |
+
if self.cfg.decoding.results_path is not None:
|
153 |
+
self.hypo_words_file = self.get_res_file("hypo.word")
|
154 |
+
self.hypo_units_file = self.get_res_file("hypo.units")
|
155 |
+
self.ref_words_file = self.get_res_file("ref.word")
|
156 |
+
self.ref_units_file = self.get_res_file("ref.units")
|
157 |
+
self.score_file = self.get_res_file("asr_score")
|
158 |
+
return self
|
159 |
+
|
160 |
+
def __exit__(self, *exc) -> bool:
|
161 |
+
if self.cfg.decoding.results_path is not None:
|
162 |
+
self.hypo_words_file.close()
|
163 |
+
self.hypo_units_file.close()
|
164 |
+
self.ref_words_file.close()
|
165 |
+
self.ref_units_file.close()
|
166 |
+
self.score_file.close()
|
167 |
+
return False
|
168 |
+
|
169 |
+
def __iter__(self) -> Any:
|
170 |
+
for sample in self.progress_bar:
|
171 |
+
if not self.cfg.common.cpu:
|
172 |
+
sample = utils.move_to_cuda(sample)
|
173 |
+
|
174 |
+
# Happens on the last batch.
|
175 |
+
if "net_input" not in sample:
|
176 |
+
continue
|
177 |
+
yield sample
|
178 |
+
|
179 |
+
def log(self, *args, **kwargs):
|
180 |
+
self.progress_bar.log(*args, **kwargs)
|
181 |
+
|
182 |
+
def print(self, *args, **kwargs):
|
183 |
+
self.progress_bar.print(*args, **kwargs)
|
184 |
+
|
185 |
+
def get_res_file(self, fname: str) -> None:
|
186 |
+
fname = os.path.join(self.cfg.decoding.results_path, fname)
|
187 |
+
if self.data_parallel_world_size > 1:
|
188 |
+
fname = f"{fname}.{self.data_parallel_rank}"
|
189 |
+
return open(fname, "w", buffering=1)
|
190 |
+
|
191 |
+
def merge_shards(self) -> None:
|
192 |
+
"""Merges all shard files into shard 0, then removes shard suffix."""
|
193 |
+
|
194 |
+
shard_id = self.data_parallel_rank
|
195 |
+
num_shards = self.data_parallel_world_size
|
196 |
+
|
197 |
+
if self.data_parallel_world_size > 1:
|
198 |
+
|
199 |
+
def merge_shards_with_root(fname: str) -> None:
|
200 |
+
fname = os.path.join(self.cfg.decoding.results_path, fname)
|
201 |
+
logger.info("Merging %s on shard %d", fname, shard_id)
|
202 |
+
base_fpath = Path(f"{fname}.0")
|
203 |
+
with open(base_fpath, "a") as out_file:
|
204 |
+
for s in range(1, num_shards):
|
205 |
+
shard_fpath = Path(f"{fname}.{s}")
|
206 |
+
with open(shard_fpath, "r") as in_file:
|
207 |
+
for line in in_file:
|
208 |
+
out_file.write(line)
|
209 |
+
shard_fpath.unlink()
|
210 |
+
shutil.move(f"{fname}.0", fname)
|
211 |
+
|
212 |
+
dist.barrier() # ensure all shards finished writing
|
213 |
+
if shard_id == (0 % num_shards):
|
214 |
+
merge_shards_with_root("hypo.word")
|
215 |
+
if shard_id == (1 % num_shards):
|
216 |
+
merge_shards_with_root("hypo.units")
|
217 |
+
if shard_id == (2 % num_shards):
|
218 |
+
merge_shards_with_root("ref.word")
|
219 |
+
if shard_id == (3 % num_shards):
|
220 |
+
merge_shards_with_root("ref.units")
|
221 |
+
dist.barrier()
|
222 |
+
|
223 |
+
def optimize_model(self, model: FairseqModel) -> None:
|
224 |
+
model.make_generation_fast_()
|
225 |
+
if self.cfg.common.fp16:
|
226 |
+
model.half()
|
227 |
+
if not self.cfg.common.cpu:
|
228 |
+
model.cuda()
|
229 |
+
|
230 |
+
def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]:
|
231 |
+
arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
|
232 |
+
models, saved_cfg = checkpoint_utils.load_model_ensemble(
|
233 |
+
utils.split_paths(self.cfg.common_eval.path, separator="\\"),
|
234 |
+
arg_overrides=arg_overrides,
|
235 |
+
task=self.task,
|
236 |
+
suffix=self.cfg.checkpoint.checkpoint_suffix,
|
237 |
+
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
|
238 |
+
num_shards=self.cfg.checkpoint.checkpoint_shard_count,
|
239 |
+
)
|
240 |
+
for model in models:
|
241 |
+
self.optimize_model(model)
|
242 |
+
return models, saved_cfg
|
243 |
+
|
244 |
+
def get_dataset_itr(self, disable_iterator_cache: bool = False) -> None:
|
245 |
+
return self.task.get_batch_iterator(
|
246 |
+
dataset=self.task.dataset(self.cfg.dataset.gen_subset),
|
247 |
+
max_tokens=self.cfg.dataset.max_tokens,
|
248 |
+
max_sentences=self.cfg.dataset.batch_size,
|
249 |
+
max_positions=(sys.maxsize, sys.maxsize),
|
250 |
+
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
|
251 |
+
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
|
252 |
+
seed=self.cfg.common.seed,
|
253 |
+
num_shards=self.data_parallel_world_size,
|
254 |
+
shard_id=self.data_parallel_rank,
|
255 |
+
num_workers=self.cfg.dataset.num_workers,
|
256 |
+
data_buffer_size=self.cfg.dataset.data_buffer_size,
|
257 |
+
disable_iterator_cache=disable_iterator_cache,
|
258 |
+
).next_epoch_itr(shuffle=False)
|
259 |
+
|
260 |
+
def build_progress_bar(
|
261 |
+
self,
|
262 |
+
epoch: Optional[int] = None,
|
263 |
+
prefix: Optional[str] = None,
|
264 |
+
default_log_format: str = "tqdm",
|
265 |
+
) -> BaseProgressBar:
|
266 |
+
return progress_bar.progress_bar(
|
267 |
+
iterator=self.get_dataset_itr(),
|
268 |
+
log_format=self.cfg.common.log_format,
|
269 |
+
log_interval=self.cfg.common.log_interval,
|
270 |
+
epoch=epoch,
|
271 |
+
prefix=prefix,
|
272 |
+
tensorboard_logdir=self.cfg.common.tensorboard_logdir,
|
273 |
+
default_log_format=default_log_format,
|
274 |
+
)
|
275 |
+
|
276 |
+
@property
|
277 |
+
def data_parallel_world_size(self):
|
278 |
+
if self.cfg.distributed_training.distributed_world_size == 1:
|
279 |
+
return 1
|
280 |
+
return distributed_utils.get_data_parallel_world_size()
|
281 |
+
|
282 |
+
@property
|
283 |
+
def data_parallel_rank(self):
|
284 |
+
if self.cfg.distributed_training.distributed_world_size == 1:
|
285 |
+
return 0
|
286 |
+
return distributed_utils.get_data_parallel_rank()
|
287 |
+
|
288 |
+
def process_sentence(
|
289 |
+
self,
|
290 |
+
sample: Dict[str, Any],
|
291 |
+
hypo: Dict[str, Any],
|
292 |
+
sid: int,
|
293 |
+
batch_id: int,
|
294 |
+
) -> Tuple[int, int]:
|
295 |
+
speaker = None # Speaker can't be parsed from dataset.
|
296 |
+
if "target_label" in sample:
|
297 |
+
toks = sample["target_label"]
|
298 |
+
else:
|
299 |
+
toks = sample["target"]
|
300 |
+
toks = toks[batch_id, :]
|
301 |
+
|
302 |
+
# Processes hypothesis.
|
303 |
+
hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu())
|
304 |
+
if "words" in hypo:
|
305 |
+
hyp_words = " ".join(hypo["words"])
|
306 |
+
else:
|
307 |
+
hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process)
|
308 |
+
|
309 |
+
# Processes target.
|
310 |
+
target_tokens = utils.strip_pad(toks, self.tgt_dict.pad())
|
311 |
+
tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu())
|
312 |
+
tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process)
|
313 |
+
|
314 |
+
if self.cfg.decoding.results_path is not None:
|
315 |
+
print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file)
|
316 |
+
print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file)
|
317 |
+
print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file)
|
318 |
+
print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file)
|
319 |
+
print(f"{hypo['score'].item()} ({speaker}-{sid})", file=self.score_file)
|
320 |
+
|
321 |
+
if not self.cfg.common_eval.quiet:
|
322 |
+
logger.info(f"HYPO: {hyp_words}")
|
323 |
+
logger.info(f"REF: {tgt_words}")
|
324 |
+
logger.info("---------------------")
|
325 |
+
|
326 |
+
hyp_words, tgt_words = hyp_words.split(), tgt_words.split()
|
327 |
+
|
328 |
+
return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
|
329 |
+
|
330 |
+
def process_sample(self, sample: Dict[str, Any]) -> None:
|
331 |
+
self.gen_timer.start()
|
332 |
+
hypos = self.task.inference_step(
|
333 |
+
generator=self.generator,
|
334 |
+
models=self.models,
|
335 |
+
sample=sample,
|
336 |
+
)
|
337 |
+
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
|
338 |
+
self.gen_timer.stop(num_generated_tokens)
|
339 |
+
self.wps_meter.update(num_generated_tokens)
|
340 |
+
|
341 |
+
for batch_id, sample_id in enumerate(sample["id"].tolist()):
|
342 |
+
errs, length = self.process_sentence(
|
343 |
+
sample=sample,
|
344 |
+
sid=sample_id,
|
345 |
+
batch_id=batch_id,
|
346 |
+
hypo=hypos[batch_id][0],
|
347 |
+
)
|
348 |
+
self.total_errors += errs
|
349 |
+
self.total_length += length
|
350 |
+
|
351 |
+
self.log({"wps": round(self.wps_meter.avg)})
|
352 |
+
if "nsentences" in sample:
|
353 |
+
self.num_sentences += sample["nsentences"]
|
354 |
+
else:
|
355 |
+
self.num_sentences += sample["id"].numel()
|
356 |
+
|
357 |
+
def log_generation_time(self) -> None:
|
358 |
+
logger.info(
|
359 |
+
"Processed %d sentences (%d tokens) in %.1fs %.2f "
|
360 |
+
"sentences per second, %.2f tokens per second)",
|
361 |
+
self.num_sentences,
|
362 |
+
self.gen_timer.n,
|
363 |
+
self.gen_timer.sum,
|
364 |
+
self.num_sentences / (self.gen_timer.sum + 1e-6),
|
365 |
+
1.0 / (self.gen_timer.avg + 1e-6),
|
366 |
+
)
|
367 |
+
|
368 |
+
|
369 |
+
def parse_wer(wer_file: Path) -> float:
|
370 |
+
with open(wer_file, "r") as f:
|
371 |
+
return float(f.readline().strip().split(" ")[1])
|
372 |
+
|
373 |
+
|
374 |
+
def get_wer_file(cfg: InferConfig) -> Path:
|
375 |
+
"""Hashes the decoding parameters to a unique file ID."""
|
376 |
+
base_path = "wer"
|
377 |
+
if cfg.decoding.results_path is not None:
|
378 |
+
base_path = os.path.join(cfg.decoding.results_path, base_path)
|
379 |
+
|
380 |
+
if cfg.decoding.unique_wer_file:
|
381 |
+
yaml_str = OmegaConf.to_yaml(cfg.decoding)
|
382 |
+
fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
|
383 |
+
return Path(f"{base_path}.{fid % 1000000}")
|
384 |
+
else:
|
385 |
+
return Path(base_path)
|
386 |
+
|
387 |
+
|
388 |
+
def main(cfg: InferConfig) -> float:
|
389 |
+
"""Entry point for main processing logic.
|
390 |
+
|
391 |
+
Args:
|
392 |
+
cfg: The inferance configuration to use.
|
393 |
+
wer: Optional shared memory pointer for returning the WER. If not None,
|
394 |
+
the final WER value will be written here instead of being returned.
|
395 |
+
|
396 |
+
Returns:
|
397 |
+
The final WER if `wer` is None, otherwise None.
|
398 |
+
"""
|
399 |
+
|
400 |
+
yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg)
|
401 |
+
|
402 |
+
# Validates the provided configuration.
|
403 |
+
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
|
404 |
+
cfg.dataset.max_tokens = 4000000
|
405 |
+
if not cfg.common.cpu and not torch.cuda.is_available():
|
406 |
+
raise ValueError("CUDA not found; set `cpu=True` to run without CUDA")
|
407 |
+
|
408 |
+
logger.info(cfg.common_eval.path)
|
409 |
+
|
410 |
+
with InferenceProcessor(cfg) as processor:
|
411 |
+
for sample in processor:
|
412 |
+
processor.process_sample(sample)
|
413 |
+
|
414 |
+
processor.log_generation_time()
|
415 |
+
|
416 |
+
if cfg.decoding.results_path is not None:
|
417 |
+
processor.merge_shards()
|
418 |
+
|
419 |
+
errs_t, leng_t = processor.total_errors, processor.total_length
|
420 |
+
|
421 |
+
if cfg.common.cpu:
|
422 |
+
logger.warning("Merging WER requires CUDA.")
|
423 |
+
elif processor.data_parallel_world_size > 1:
|
424 |
+
stats = torch.LongTensor([errs_t, leng_t]).cuda()
|
425 |
+
dist.all_reduce(stats, op=dist.ReduceOp.SUM)
|
426 |
+
errs_t, leng_t = stats[0].item(), stats[1].item()
|
427 |
+
|
428 |
+
wer = errs_t * 100.0 / leng_t
|
429 |
+
|
430 |
+
if distributed_utils.is_master(cfg.distributed_training):
|
431 |
+
with open(wer_file, "w") as f:
|
432 |
+
f.write(
|
433 |
+
(
|
434 |
+
f"WER: {wer}\n"
|
435 |
+
f"err / num_ref_words = {errs_t} / {leng_t}\n\n"
|
436 |
+
f"{yaml_str}"
|
437 |
+
)
|
438 |
+
)
|
439 |
+
|
440 |
+
return wer
|
441 |
+
|
442 |
+
|
443 |
+
@hydra.main(config_path=config_path, config_name="infer")
|
444 |
+
def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
|
445 |
+
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
|
446 |
+
cfg = OmegaConf.create(container)
|
447 |
+
OmegaConf.set_struct(cfg, True)
|
448 |
+
|
449 |
+
if cfg.common.reset_logging:
|
450 |
+
reset_logging()
|
451 |
+
|
452 |
+
utils.import_user_module(cfg.common)
|
453 |
+
|
454 |
+
# logger.info("Config:\n%s", OmegaConf.to_yaml(cfg))
|
455 |
+
wer = float("inf")
|
456 |
+
|
457 |
+
try:
|
458 |
+
if cfg.common.profile:
|
459 |
+
with torch.cuda.profiler.profile():
|
460 |
+
with torch.autograd.profiler.emit_nvtx():
|
461 |
+
distributed_utils.call_main(cfg, main)
|
462 |
+
else:
|
463 |
+
distributed_utils.call_main(cfg, main)
|
464 |
+
|
465 |
+
wer = parse_wer(get_wer_file(cfg))
|
466 |
+
except BaseException as e: # pylint: disable=broad-except
|
467 |
+
if not cfg.common.suppress_crashes:
|
468 |
+
raise
|
469 |
+
else:
|
470 |
+
logger.error("Crashed! %s", str(e))
|
471 |
+
|
472 |
+
logger.info("Word error rate: %.4f", wer)
|
473 |
+
if cfg.is_ax:
|
474 |
+
return wer, None
|
475 |
+
|
476 |
+
return wer
|
477 |
+
|
478 |
+
|
479 |
+
def cli_main() -> None:
|
480 |
+
try:
|
481 |
+
from hydra._internal.utils import (
|
482 |
+
get_args,
|
483 |
+
) # pylint: disable=import-outside-toplevel
|
484 |
+
|
485 |
+
cfg_name = get_args().config_name or "infer"
|
486 |
+
except ImportError:
|
487 |
+
logger.warning("Failed to get config name from hydra args")
|
488 |
+
cfg_name = "infer"
|
489 |
+
|
490 |
+
cs = ConfigStore.instance()
|
491 |
+
cs.store(name=cfg_name, node=InferConfig)
|
492 |
+
|
493 |
+
for k in InferConfig.__dataclass_fields__:
|
494 |
+
if is_dataclass(InferConfig.__dataclass_fields__[k].type):
|
495 |
+
v = InferConfig.__dataclass_fields__[k].default
|
496 |
+
cs.store(name=k, node=v)
|
497 |
+
|
498 |
+
hydra_main() # pylint: disable=no-value-for-parameter
|
499 |
+
|
500 |
+
|
501 |
+
if __name__ == "__main__":
|
502 |
+
cli_main()
|
fairseq/examples/speech_recognition/tasks/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
|
5 |
+
for file in sorted(os.listdir(os.path.dirname(__file__))):
|
6 |
+
if file.endswith(".py") and not file.startswith("_"):
|
7 |
+
task_name = file[: file.find(".py")]
|
8 |
+
importlib.import_module("examples.speech_recognition.tasks." + task_name)
|
fairseq/examples/speech_recognition/tasks/speech_recognition.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import sys
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from examples.speech_recognition.data import AsrDataset
|
13 |
+
from examples.speech_recognition.data.replabels import replabel_symbol
|
14 |
+
from fairseq.data import Dictionary
|
15 |
+
from fairseq.tasks import LegacyFairseqTask, register_task
|
16 |
+
|
17 |
+
|
18 |
+
def get_asr_dataset_from_json(data_json_path, tgt_dict):
|
19 |
+
"""
|
20 |
+
Parse data json and create dataset.
|
21 |
+
See scripts/asr_prep_json.py which pack json from raw files
|
22 |
+
|
23 |
+
Json example:
|
24 |
+
{
|
25 |
+
"utts": {
|
26 |
+
"4771-29403-0025": {
|
27 |
+
"input": {
|
28 |
+
"length_ms": 170,
|
29 |
+
"path": "/tmp/file1.flac"
|
30 |
+
},
|
31 |
+
"output": {
|
32 |
+
"text": "HELLO \n",
|
33 |
+
"token": "HE LLO",
|
34 |
+
"tokenid": "4815, 861"
|
35 |
+
}
|
36 |
+
},
|
37 |
+
"1564-142299-0096": {
|
38 |
+
...
|
39 |
+
}
|
40 |
+
}
|
41 |
+
"""
|
42 |
+
if not os.path.isfile(data_json_path):
|
43 |
+
raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
|
44 |
+
with open(data_json_path, "rb") as f:
|
45 |
+
data_samples = json.load(f)["utts"]
|
46 |
+
assert len(data_samples) != 0
|
47 |
+
sorted_samples = sorted(
|
48 |
+
data_samples.items(),
|
49 |
+
key=lambda sample: int(sample[1]["input"]["length_ms"]),
|
50 |
+
reverse=True,
|
51 |
+
)
|
52 |
+
aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
|
53 |
+
ids = [s[0] for s in sorted_samples]
|
54 |
+
speakers = []
|
55 |
+
for s in sorted_samples:
|
56 |
+
m = re.search("(.+?)-(.+?)-(.+?)", s[0])
|
57 |
+
speakers.append(m.group(1) + "_" + m.group(2))
|
58 |
+
frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
|
59 |
+
tgt = [
|
60 |
+
[int(i) for i in s[1]["output"]["tokenid"].split(", ")]
|
61 |
+
for s in sorted_samples
|
62 |
+
]
|
63 |
+
# append eos
|
64 |
+
tgt = [[*t, tgt_dict.eos()] for t in tgt]
|
65 |
+
return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
|
66 |
+
|
67 |
+
|
68 |
+
@register_task("speech_recognition")
|
69 |
+
class SpeechRecognitionTask(LegacyFairseqTask):
|
70 |
+
"""
|
71 |
+
Task for training speech recognition model.
|
72 |
+
"""
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def add_args(parser):
|
76 |
+
"""Add task-specific arguments to the parser."""
|
77 |
+
parser.add_argument("data", help="path to data directory")
|
78 |
+
parser.add_argument(
|
79 |
+
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--max-source-positions",
|
83 |
+
default=sys.maxsize,
|
84 |
+
type=int,
|
85 |
+
metavar="N",
|
86 |
+
help="max number of frames in the source sequence",
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"--max-target-positions",
|
90 |
+
default=1024,
|
91 |
+
type=int,
|
92 |
+
metavar="N",
|
93 |
+
help="max number of tokens in the target sequence",
|
94 |
+
)
|
95 |
+
|
96 |
+
def __init__(self, args, tgt_dict):
|
97 |
+
super().__init__(args)
|
98 |
+
self.tgt_dict = tgt_dict
|
99 |
+
|
100 |
+
@classmethod
|
101 |
+
def setup_task(cls, args, **kwargs):
|
102 |
+
"""Setup the task (e.g., load dictionaries)."""
|
103 |
+
dict_path = os.path.join(args.data, "dict.txt")
|
104 |
+
if not os.path.isfile(dict_path):
|
105 |
+
raise FileNotFoundError("Dict not found: {}".format(dict_path))
|
106 |
+
tgt_dict = Dictionary.load(dict_path)
|
107 |
+
|
108 |
+
if args.criterion == "ctc_loss":
|
109 |
+
tgt_dict.add_symbol("<ctc_blank>")
|
110 |
+
elif args.criterion == "asg_loss":
|
111 |
+
for i in range(1, args.max_replabel + 1):
|
112 |
+
tgt_dict.add_symbol(replabel_symbol(i))
|
113 |
+
|
114 |
+
print("| dictionary: {} types".format(len(tgt_dict)))
|
115 |
+
return cls(args, tgt_dict)
|
116 |
+
|
117 |
+
def load_dataset(self, split, combine=False, **kwargs):
|
118 |
+
"""Load a given dataset split.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
split (str): name of the split (e.g., train, valid, test)
|
122 |
+
"""
|
123 |
+
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
|
124 |
+
self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
|
125 |
+
|
126 |
+
def build_generator(self, models, args, **unused):
|
127 |
+
w2l_decoder = getattr(args, "w2l_decoder", None)
|
128 |
+
if w2l_decoder == "viterbi":
|
129 |
+
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
|
130 |
+
|
131 |
+
return W2lViterbiDecoder(args, self.target_dictionary)
|
132 |
+
elif w2l_decoder == "kenlm":
|
133 |
+
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
|
134 |
+
|
135 |
+
return W2lKenLMDecoder(args, self.target_dictionary)
|
136 |
+
elif w2l_decoder == "fairseqlm":
|
137 |
+
from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
|
138 |
+
|
139 |
+
return W2lFairseqLMDecoder(args, self.target_dictionary)
|
140 |
+
else:
|
141 |
+
return super().build_generator(models, args)
|
142 |
+
|
143 |
+
@property
|
144 |
+
def target_dictionary(self):
|
145 |
+
"""Return the :class:`~fairseq.data.Dictionary` for the language
|
146 |
+
model."""
|
147 |
+
return self.tgt_dict
|
148 |
+
|
149 |
+
@property
|
150 |
+
def source_dictionary(self):
|
151 |
+
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
|
152 |
+
for this task)."""
|
153 |
+
return None
|
154 |
+
|
155 |
+
def max_positions(self):
|
156 |
+
"""Return the max speech and sentence length allowed by the task."""
|
157 |
+
return (self.args.max_source_positions, self.args.max_target_positions)
|
fairseq/examples/speech_recognition/utils/wer_utils.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
9 |
+
|
10 |
+
import re
|
11 |
+
from collections import deque
|
12 |
+
from enum import Enum
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
"""
Utility modules for computation of Word Error Rate,
Alignments, as well as more granular metrics like
deletion, insertion and substitution.
"""


class Code(Enum):
    match = 1
    substitution = 2
    insertion = 3
    deletion = 4


class Token(object):
    def __init__(self, lbl="", st=np.nan, en=np.nan):
        if np.isnan(st):
            self.label, self.start, self.end = "", 0.0, 0.0
        else:
            self.label, self.start, self.end = lbl, st, en


class AlignmentResult(object):
    def __init__(self, refs, hyps, codes, score):
        self.refs = refs  # std::deque<int>
        self.hyps = hyps  # std::deque<int>
        self.codes = codes  # std::deque<Code>
        self.score = score  # float


def coordinate_to_offset(row, col, ncols):
    return int(row * ncols + col)


def offset_to_row(offset, ncols):
    return int(offset / ncols)


def offset_to_col(offset, ncols):
    return int(offset % ncols)


def trimWhitespace(str):
    return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str)))


def str2toks(str):
    pieces = trimWhitespace(str).split(" ")
    toks = []
    for p in pieces:
        toks.append(Token(p, 0.0, 0.0))
    return toks


class EditDistance(object):
    def __init__(self, time_mediated):
        self.time_mediated_ = time_mediated
        self.scores_ = np.nan  # Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>
        self.backtraces_ = (
            np.nan
        )  # Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic> backtraces_;
        self.confusion_pairs_ = {}

    def cost(self, ref, hyp, code):
        if self.time_mediated_:
            if code == Code.match:
                return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
            elif code == Code.insertion:
                return hyp.end - hyp.start
            elif code == Code.deletion:
                return ref.end - ref.start
            else:  # substitution
                return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
        else:
            if code == Code.match:
                return 0
            elif code == Code.insertion or code == Code.deletion:
                return 3
            else:  # substitution
                return 4

    def get_result(self, refs, hyps):
        res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)

        num_rows, num_cols = self.scores_.shape
        res.score = self.scores_[num_rows - 1, num_cols - 1]

        curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)

        while curr_offset != 0:
            curr_row = offset_to_row(curr_offset, num_cols)
            curr_col = offset_to_col(curr_offset, num_cols)

            prev_offset = self.backtraces_[curr_row, curr_col]

            prev_row = offset_to_row(prev_offset, num_cols)
            prev_col = offset_to_col(prev_offset, num_cols)

            res.refs.appendleft(curr_row - 1)  # Note: this was .push_front() in C++
            res.hyps.appendleft(curr_col - 1)
            if curr_row - 1 == prev_row and curr_col == prev_col:
                res.codes.appendleft(Code.deletion)
            elif curr_row == prev_row and curr_col - 1 == prev_col:
                res.codes.appendleft(Code.insertion)
            else:
                # assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
                ref_str = refs[res.refs[0]].label
                hyp_str = hyps[res.hyps[0]].label

                if ref_str == hyp_str:
                    res.codes.appendleft(Code.match)
                else:
                    res.codes.appendleft(Code.substitution)

                    confusion_pair = "%s -> %s" % (ref_str, hyp_str)
                    if confusion_pair not in self.confusion_pairs_:
                        self.confusion_pairs_[confusion_pair] = 1
                    else:
                        self.confusion_pairs_[confusion_pair] += 1

            curr_offset = prev_offset

        return res

    def align(self, refs, hyps):
        if len(refs) == 0 and len(hyps) == 0:
            return np.nan

        # NOTE: we're not resetting the values in these matrices because every value
        # will be overridden in the loop below. If this assumption doesn't hold,
        # be sure to set all entries in self.scores_ and self.backtraces_ to 0.
        self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
        self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))

        num_rows, num_cols = self.scores_.shape

        for i in range(num_rows):
            for j in range(num_cols):
                if i == 0 and j == 0:
                    self.scores_[i, j] = 0.0
                    self.backtraces_[i, j] = 0
                    continue

                if i == 0:
                    self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
                        None, hyps[j - 1], Code.insertion
                    )
                    self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
                    continue

                if j == 0:
                    self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
                        refs[i - 1], None, Code.deletion
                    )
                    self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
                    continue

                # Below here both i and j are greater than 0
                ref = refs[i - 1]
                hyp = hyps[j - 1]
                best_score = self.scores_[i - 1, j - 1] + (
                    self.cost(ref, hyp, Code.match)
                    if (ref.label == hyp.label)
                    else self.cost(ref, hyp, Code.substitution)
                )

                prev_row = i - 1
                prev_col = j - 1
                ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
                if ins < best_score:
                    best_score = ins
                    prev_row = i
                    prev_col = j - 1

                delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
                if delt < best_score:
                    best_score = delt
                    prev_row = i - 1
                    prev_col = j

                self.scores_[i, j] = best_score
                self.backtraces_[i, j] = coordinate_to_offset(
                    prev_row, prev_col, num_cols
                )

        return self.get_result(refs, hyps)


class WERTransformer(object):
    def __init__(self, hyp_str, ref_str, verbose=True):
        self.ed_ = EditDistance(False)
        self.id2oracle_errs_ = {}
        self.utts_ = 0
        self.words_ = 0
        self.insertions_ = 0
        self.deletions_ = 0
        self.substitutions_ = 0

        self.process(["dummy_str", hyp_str, ref_str])

        if verbose:
            print("'%s' vs '%s'" % (hyp_str, ref_str))
            self.report_result()

    def process(self, input):  # std::vector<std::string>&& input
        if len(input) < 3:
            print(
                "Input must be of the form <id> ... <hypo> <ref> , got ",
                len(input),
                " inputs:",
            )
            return None

        # Align
        # std::vector<Token> hyps;
        # std::vector<Token> refs;

        hyps = str2toks(input[-2])
        refs = str2toks(input[-1])

        alignment = self.ed_.align(refs, hyps)
        if alignment is None:
            print("Alignment is null")
            return np.nan

        # Tally errors
        ins = 0
        dels = 0
        subs = 0
        for code in alignment.codes:
            if code == Code.substitution:
                subs += 1
            elif code == Code.insertion:
                ins += 1
            elif code == Code.deletion:
                dels += 1

        # Output
        row = input
        row.append(str(len(refs)))
        row.append(str(ins))
        row.append(str(dels))
        row.append(str(subs))
        # print(row)

        # Accumulate
        kIdIndex = 0
        kNBestSep = "/"

        pieces = input[kIdIndex].split(kNBestSep)

        if len(pieces) == 0:
            print(
                "Error splitting ",
                input[kIdIndex],
                " on '",
                kNBestSep,
                "', got empty list",
            )
            return np.nan

        id = pieces[0]
        if id not in self.id2oracle_errs_:
            self.utts_ += 1
            self.words_ += len(refs)
            self.insertions_ += ins
            self.deletions_ += dels
            self.substitutions_ += subs
            self.id2oracle_errs_[id] = [ins, dels, subs]
        else:
            curr_err = ins + dels + subs
            prev_err = np.sum(self.id2oracle_errs_[id])
            if curr_err < prev_err:
                self.id2oracle_errs_[id] = [ins, dels, subs]

        return 0

    def report_result(self):
        # print("---------- Summary ---------------")
        if self.words_ == 0:
            print("No words counted")
            return

        # 1-best
        best_wer = (
            100.0
            * (self.insertions_ + self.deletions_ + self.substitutions_)
            / self.words_
        )

        print(
            "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
            "%0.2f%% dels, %0.2f%% subs)"
            % (
                best_wer,
                self.utts_,
                self.words_,
                100.0 * self.insertions_ / self.words_,
                100.0 * self.deletions_ / self.words_,
                100.0 * self.substitutions_ / self.words_,
            )
        )

    def wer(self):
        if self.words_ == 0:
            wer = np.nan
        else:
            wer = (
                100.0
                * (self.insertions_ + self.deletions_ + self.substitutions_)
                / self.words_
            )
        return wer

    def stats(self):
        if self.words_ == 0:
            stats = {}
        else:
            wer = (
                100.0
                * (self.insertions_ + self.deletions_ + self.substitutions_)
                / self.words_
            )
            stats = dict(
                {
                    "wer": wer,
                    "utts": self.utts_,
                    "numwords": self.words_,
                    "ins": self.insertions_,
                    "dels": self.deletions_,
                    "subs": self.substitutions_,
                    "confusion_pairs": self.ed_.confusion_pairs_,
                }
            )
        return stats


def calc_wer(hyp_str, ref_str):
    t = WERTransformer(hyp_str, ref_str, verbose=0)
    return t.wer()


def calc_wer_stats(hyp_str, ref_str):
    t = WERTransformer(hyp_str, ref_str, verbose=0)
    return t.stats()


def get_wer_alignment_codes(hyp_str, ref_str):
    """
    INPUT: hypothesis string, reference string
    OUTPUT: List of alignment codes (intermediate results from WER computation)
    """
    t = WERTransformer(hyp_str, ref_str, verbose=0)
    return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes


def merge_counts(x, y):
    # Merge two hashes which have 'counts' as their values
    # This can be used for example to merge confusion pair counts
    # conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
    for k, v in y.items():
        if k not in x:
            x[k] = 0
        x[k] += v
    return x
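The WER helpers above are plain functions, so they can be exercised directly. Below is a minimal usage sketch, assuming the fairseq repository root is on `PYTHONPATH` so the module path of the file added in this diff resolves; the hypothesis/reference strings are made up for illustration.

```python
# Minimal usage sketch for the wer_utils helpers above (not part of the diff).
from examples.speech_recognition.utils.wer_utils import (
    calc_wer,
    calc_wer_stats,
    merge_counts,
)

hyp = "the cat sat on mat"
ref = "the cat sat on the mat"

# WER in percent; here one reference word is missing from the hypothesis,
# so roughly 100 * 1 / 6 ~= 16.67
print(calc_wer(hyp, ref))

stats = calc_wer_stats(hyp, ref)
print(stats["ins"], stats["dels"], stats["subs"], stats["numwords"])

# Aggregate confusion-pair counts across utterances, as suggested in merge_counts
conf_pairs = {}
conf_pairs = merge_counts(conf_pairs, stats["confusion_pairs"])
print(conf_pairs)
```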
fairseq/examples/speech_recognition/w2l_decoder.py
ADDED
@@ -0,0 +1,486 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Flashlight decoders.
"""

import gc
import itertools as it
import os.path as osp
from typing import List
import warnings
from collections import deque, namedtuple

import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
from omegaconf import open_dict
from fairseq.dataclass.utils import convert_namespace_to_omegaconf


try:
    from flashlight.lib.text.dictionary import create_word_dict, load_words
    from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes
    from flashlight.lib.text.decoder import (
        CriterionType,
        LexiconDecoderOptions,
        KenLM,
        LM,
        LMState,
        SmearingMode,
        Trie,
        LexiconDecoder,
    )
except:
    warnings.warn(
        "flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
    )
    LM = object
    LMState = object


class W2lDecoder(object):
    def __init__(self, args, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = args.nbest

        # criterion-specific init
        self.criterion_type = CriterionType.CTC
        self.blank = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        if "<sep>" in tgt_dict.indices:
            self.silence = tgt_dict.index("<sep>")
        elif "|" in tgt_dict.indices:
            self.silence = tgt_dict.index("|")
        else:
            self.silence = tgt_dict.eos()
        self.asg_transitions = None

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        model = models[0]
        encoder_out = model(**encoder_input)
        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out)  # no need to normalize emissions
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))


class W2lViterbiDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]


class W2lKenLMDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        self.unit_lm = getattr(args, "unit_lm", False)

        if args.lexicon:
            self.lexicon = load_words(args.lexicon)
            self.word_dict = create_word_dict(self.lexicon)
            self.unk_word = self.word_dict.get_index("<unk>")

            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.trie = Trie(self.vocab_size, self.silence)

            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                word_idx = self.word_dict.get_index(word)
                _, score = self.lm.score(start_state, word_idx)
                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                word_score=args.word_score,
                unk_score=args.unk_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            if self.asg_transitions is None:
                N = 768
                # self.asg_transitions = torch.FloatTensor(N, N).zero_()
                self.asg_transitions = []

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                self.asg_transitions,
                self.unit_lm,
            )
        else:
            assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions

            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
        """Returns frame numbers corresponding to every non-blank token.

        Parameters
        ----------
        token_idxs : List[int]
            IDs of decoded tokens.

        Returns
        -------
        List[int]
            Frame numbers corresponding to every non-blank token.
        """
        timesteps = []
        for i, token_idx in enumerate(token_idxs):
            if token_idx == self.blank:
                continue
            if i == 0 or token_idx != token_idxs[i - 1]:
                timesteps.append(i)
        return timesteps

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        "timesteps": self.get_timesteps(result.tokens),
                        "words": [
                            self.word_dict.get_entry(x) for x in result.words if x >= 0
                        ],
                    }
                    for result in nbest_results
                ]
            )
        return hypos


FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])


class FairseqLM(LM):
    def __init__(self, dictionary, model):
        LM.__init__(self)
        self.dictionary = dictionary
        self.model = model
        self.unk = self.dictionary.unk()

        self.save_incremental = False  # this currently does not work properly
        self.max_cache = 20_000

        model.cuda()
        model.eval()
        model.make_generation_fast_()

        self.states = {}
        self.stateq = deque()

    def start(self, start_with_nothing):
        state = LMState()
        prefix = torch.LongTensor([[self.dictionary.eos()]])
        incremental_state = {} if self.save_incremental else None
        with torch.no_grad():
            res = self.model(prefix.cuda(), incremental_state=incremental_state)
            probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)

        if incremental_state is not None:
            incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
        self.states[state] = FairseqLMState(
            prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
        )
        self.stateq.append(state)

        return state

    def score(self, state: LMState, token_index: int, no_cache: bool = False):
        """
        Evaluate language model based on the current lm state and new word
        Parameters:
        -----------
        state: current lm state
        token_index: index of the word
        (can be lexicon index then you should store inside LM the
        mapping between indices of lexicon and lm, or lm index of a word)

        Returns:
        --------
        (LMState, float): pair of (new state, score for the current word)
        """
        curr_state = self.states[state]

        def trim_cache(targ_size):
            while len(self.stateq) > targ_size:
                rem_k = self.stateq.popleft()
                rem_st = self.states[rem_k]
                rem_st = FairseqLMState(rem_st.prefix, None, None)
                self.states[rem_k] = rem_st

        if curr_state.probs is None:
            new_incremental_state = (
                curr_state.incremental_state.copy()
                if curr_state.incremental_state is not None
                else None
            )
            with torch.no_grad():
                if new_incremental_state is not None:
                    new_incremental_state = apply_to_sample(
                        lambda x: x.cuda(), new_incremental_state
                    )
                elif self.save_incremental:
                    new_incremental_state = {}

                res = self.model(
                    torch.from_numpy(curr_state.prefix).cuda(),
                    incremental_state=new_incremental_state,
                )
                probs = self.model.get_normalized_probs(
                    res, log_probs=True, sample=None
                )

                if new_incremental_state is not None:
                    new_incremental_state = apply_to_sample(
                        lambda x: x.cpu(), new_incremental_state
                    )

                curr_state = FairseqLMState(
                    curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
                )

            if not no_cache:
                self.states[state] = curr_state
                self.stateq.append(state)

        score = curr_state.probs[token_index].item()

        trim_cache(self.max_cache)

        outstate = state.child(token_index)
        if outstate not in self.states and not no_cache:
            prefix = np.concatenate(
                [curr_state.prefix, torch.LongTensor([[token_index]])], -1
            )
            incr_state = curr_state.incremental_state

            self.states[outstate] = FairseqLMState(prefix, incr_state, None)

        if token_index == self.unk:
            score = float("-inf")

        return outstate, score

    def finish(self, state: LMState):
        """
        Evaluate eos for language model based on the current lm state

        Returns:
        --------
        (LMState, float): pair of (new state, score for the current word)
        """
        return self.score(state, self.dictionary.eos())

    def empty_cache(self):
        self.states = {}
        self.stateq = deque()
        gc.collect()


class W2lFairseqLMDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        self.unit_lm = getattr(args, "unit_lm", False)

        self.lexicon = load_words(args.lexicon) if args.lexicon else None
        self.idx_to_wrd = {}

        checkpoint = torch.load(args.kenlm_model, map_location="cpu")

        if "cfg" in checkpoint and checkpoint["cfg"] is not None:
            lm_args = checkpoint["cfg"]
        else:
            lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

        with open_dict(lm_args.task):
            lm_args.task.data = osp.dirname(args.kenlm_model)

        task = tasks.setup_task(lm_args.task)
        model = task.build_model(lm_args.model)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unit_lm:
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state, word_idx, no_cache=True)

                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                word_score=args.word_score,
                unk_score=args.unk_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                [],
                self.unit_lm,
            )
        else:
            assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions

            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []

        def idx_to_word(idx):
            if self.unit_lm:
                return self.idx_to_wrd[idx]
            else:
                return self.word_dict[idx]

        def make_hypo(result):
            hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
            if self.lexicon:
                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
            return hypo

        for b in range(B):
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append([make_hypo(result) for result in nbest_results])
            self.lm.empty_cache()

        return hypos
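All decoders above post-process raw CTC index sequences through `W2lDecoder.get_tokens`, which collapses runs of repeated indices and drops the blank symbol. The snippet below is a standalone sketch of that step; the blank id of 0 and the index values are illustrative assumptions, whereas in the decoder the blank comes from `tgt_dict`.

```python
# Standalone sketch of the CTC post-processing performed by W2lDecoder.get_tokens.
import itertools as it

import torch


def ctc_postprocess(idxs, blank=0):
    # keep one element per run of equal indices, then filter out the blank id
    idxs = (g[0] for g in it.groupby(idxs))
    idxs = filter(lambda x: x != blank, idxs)
    return torch.LongTensor(list(idxs))


# A frame-level CTC argmax path with blanks (0) and repeats collapses to [7, 12, 5]
print(ctc_postprocess([0, 7, 7, 0, 0, 12, 12, 12, 5, 0]))
```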
fairseq/examples/speech_synthesis/README.md
ADDED
@@ -0,0 +1,38 @@
Speech Synthesis (S^2)
===
[https://arxiv.org/abs/2109.06912](https://arxiv.org/abs/2109.06912)

Speech synthesis with fairseq.

## Features

- Autoregressive and non-autoregressive models
- Multi-speaker synthesis
- Audio preprocessing (denoising, VAD, etc.) for less curated data
- Automatic metrics for model development
- Similar data configuration as [S2T](../speech_to_text/README.md)


## Examples
- [Single-speaker synthesis on LJSpeech](docs/ljspeech_example.md)
- [Multi-speaker synthesis on VCTK](docs/vctk_example.md)
- [Multi-speaker synthesis on Common Voice](docs/common_voice_example.md)


## Citation
Please cite as:
```
@article{wang2021fairseqs2,
  title={fairseq S\^{} 2: A Scalable and Integrable Speech Synthesis Toolkit},
  author={Wang, Changhan and Hsu, Wei-Ning and Adi, Yossi and Polyak, Adam and Lee, Ann and Chen, Peng-Jen and Gu, Jiatao and Pino, Juan},
  journal={arXiv preprint arXiv:2109.06912},
  year={2021}
}

@inproceedings{ott2019fairseq,
  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
  year = {2019},
}
```
fairseq/examples/speech_synthesis/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
fairseq/examples/speech_synthesis/data_utils.py
ADDED
@@ -0,0 +1,344 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import os
from pathlib import Path
from typing import Optional, List, Dict
import zipfile
import tempfile
from dataclasses import dataclass
from itertools import groupby

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from examples.speech_to_text.data_utils import load_tsv_to_dicts
from fairseq.data.audio.audio_utils import (
    TTSSpectrogram, TTSMelScale, parse_path, read_from_stored_zip, is_npy_data
)


def trim_or_pad_to_target_length(
        data_1d_or_2d: np.ndarray, target_length: int
) -> np.ndarray:
    assert len(data_1d_or_2d.shape) in {1, 2}
    delta = data_1d_or_2d.shape[0] - target_length
    if delta >= 0:  # trim if being longer
        data_1d_or_2d = data_1d_or_2d[: target_length]
    else:  # pad if being shorter
        if len(data_1d_or_2d.shape) == 1:
            data_1d_or_2d = np.concatenate(
                [data_1d_or_2d, np.zeros(-delta)], axis=0
            )
        else:
            data_1d_or_2d = np.concatenate(
                [data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))],
                axis=0
            )
    return data_1d_or_2d


def extract_logmel_spectrogram(
        waveform: torch.Tensor, sample_rate: int,
        output_path: Optional[Path] = None, win_length: int = 1024,
        hop_length: int = 256, n_fft: int = 1024,
        win_fn: callable = torch.hann_window, n_mels: int = 80,
        f_min: float = 0., f_max: float = 8000, eps: float = 1e-5,
        overwrite: bool = False, target_length: Optional[int] = None
):
    if output_path is not None and output_path.is_file() and not overwrite:
        return

    spectrogram_transform = TTSSpectrogram(
        n_fft=n_fft, win_length=win_length, hop_length=hop_length,
        window_fn=win_fn
    )
    mel_scale_transform = TTSMelScale(
        n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
        n_stft=n_fft // 2 + 1
    )
    spectrogram = spectrogram_transform(waveform)
    mel_spec = mel_scale_transform(spectrogram)
    logmel_spec = torch.clamp(mel_spec, min=eps).log()
    assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
    logmel_spec = logmel_spec.squeeze().t()  # D x T -> T x D
    if target_length is not None:
        logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)

    if output_path is not None:
        np.save(output_path.as_posix(), logmel_spec)
    else:
        return logmel_spec


def extract_pitch(
        waveform: torch.Tensor, sample_rate: int,
        output_path: Optional[Path] = None, hop_length: int = 256,
        log_scale: bool = True, phoneme_durations: Optional[List[int]] = None
):
    if output_path is not None and output_path.is_file():
        return

    try:
        import pyworld
    except ImportError:
        raise ImportError("Please install PyWORLD: pip install pyworld")

    _waveform = waveform.squeeze(0).double().numpy()
    pitch, t = pyworld.dio(
        _waveform, sample_rate, frame_period=hop_length / sample_rate * 1000
    )
    pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate)

    if phoneme_durations is not None:
        pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations))
        try:
            from scipy.interpolate import interp1d
        except ImportError:
            raise ImportError("Please install SciPy: pip install scipy")
        nonzero_ids = np.where(pitch != 0)[0]
        if len(nonzero_ids) == 0:
            print((f"{output_path} has all empty values in the pitch contour"))
            return
        elif len(nonzero_ids) == 1:
            print((f"{output_path} has only one non-zero values in the pitch contour"))
            return
        else:
            interp_fn = interp1d(
                nonzero_ids,
                pitch[nonzero_ids],
                fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
                bounds_error=False,
            )
            pitch = interp_fn(np.arange(0, len(pitch)))
        d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
        pitch = np.array(
            [
                np.mean(pitch[d_cumsum[i-1]: d_cumsum[i]])
                for i in range(1, len(d_cumsum))
            ]
        )
        assert len(pitch) == len(phoneme_durations)

    if log_scale:
        pitch = np.log(pitch + 1)

    if output_path is not None:
        np.save(output_path.as_posix(), pitch)
    else:
        return pitch


def extract_energy(
        waveform: torch.Tensor, output_path: Optional[Path] = None,
        hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True,
        phoneme_durations: Optional[List[int]] = None
):
    if output_path is not None and output_path.is_file():
        return

    assert len(waveform.shape) == 2 and waveform.shape[0] == 1
    waveform = waveform.view(1, 1, waveform.shape[1])
    waveform = F.pad(
        waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0],
        mode="reflect"
    )
    waveform = waveform.squeeze(1)

    fourier_basis = np.fft.fft(np.eye(n_fft))
    cutoff = int((n_fft / 2 + 1))
    fourier_basis = np.vstack(
        [np.real(fourier_basis[:cutoff, :]),
         np.imag(fourier_basis[:cutoff, :])]
    )

    forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
    forward_transform = F.conv1d(
        waveform, forward_basis, stride=hop_length, padding=0
    )

    real_part = forward_transform[:, :cutoff, :]
    imag_part = forward_transform[:, cutoff:, :]
    magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
    energy = torch.norm(magnitude, dim=1).squeeze(0).numpy()

    if phoneme_durations is not None:
        energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations))
        d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
        energy = np.array(
            [
                np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]])
                for i in range(1, len(d_cumsum))
            ]
        )
        assert len(energy) == len(phoneme_durations)

    if log_scale:
        energy = np.log(energy + 1)

    if output_path is not None:
        np.save(output_path.as_posix(), energy)
    else:
        return energy


def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None):
    mean_x, mean_x2, n_frames = None, None, 0
    feature_paths = feature_root.glob("*.npy")
    for p in tqdm(feature_paths):
        with open(p, 'rb') as f:
            frames = np.load(f).squeeze()

        n_frames += frames.shape[0]

        cur_mean_x = frames.sum(axis=0)
        if mean_x is None:
            mean_x = cur_mean_x
        else:
            mean_x += cur_mean_x

        cur_mean_x2 = (frames ** 2).sum(axis=0)
        if mean_x2 is None:
            mean_x2 = cur_mean_x2
        else:
            mean_x2 += cur_mean_x2

    mean_x /= n_frames
    mean_x2 /= n_frames
    var_x = mean_x2 - mean_x ** 2
    std_x = np.sqrt(np.maximum(var_x, 1e-10))

    if output_path is not None:
        with open(output_path, 'wb') as f:
            np.savez(f, mean=mean_x, std=std_x)
    else:
        return {"mean": mean_x, "std": std_x}


def ipa_phonemize(text, lang="en-us", use_g2p=False):
    if use_g2p:
        assert lang == "en-us", "g2pE phonemizer only works for en-us"
        try:
            from g2p_en import G2p
            g2p = G2p()
            return " ".join("|" if p == " " else p for p in g2p(text))
        except ImportError:
            raise ImportError(
                "Please install phonemizer: pip install g2p_en"
            )
    else:
        try:
            from phonemizer import phonemize
            from phonemizer.separator import Separator
            return phonemize(
                text, backend='espeak', language=lang,
                separator=Separator(word="| ", phone=" ")
            )
        except ImportError:
            raise ImportError(
                "Please install phonemizer: pip install phonemizer"
            )


@dataclass
class ForceAlignmentInfo(object):
    tokens: List[str]
    frame_durations: List[int]
    start_sec: Optional[float]
    end_sec: Optional[float]


def get_mfa_alignment_by_sample_id(
        textgrid_zip_path: str, sample_id: str, sample_rate: int,
        hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn")
) -> ForceAlignmentInfo:
    try:
        import tgt
    except ImportError:
        raise ImportError("Please install TextGridTools: pip install tgt")

    filename = f"{sample_id}.TextGrid"
    out_root = Path(tempfile.gettempdir())
    tgt_path = out_root / filename
    with zipfile.ZipFile(textgrid_zip_path) as f_zip:
        f_zip.extract(filename, path=out_root)
    textgrid = tgt.io.read_textgrid(tgt_path.as_posix())
    os.remove(tgt_path)

    phones, frame_durations = [], []
    start_sec, end_sec, end_idx = 0, 0, 0
    for t in textgrid.get_tier_by_name("phones")._objects:
        s, e, p = t.start_time, t.end_time, t.text
        # Trim leading silences
        if len(phones) == 0:
            if p in silence_phones:
                continue
            else:
                start_sec = s
        phones.append(p)
        if p not in silence_phones:
            end_sec = e
            end_idx = len(phones)
        r = sample_rate / hop_length
        frame_durations.append(int(np.round(e * r) - np.round(s * r)))
    # Trim tailing silences
    phones = phones[:end_idx]
    frame_durations = frame_durations[:end_idx]

    return ForceAlignmentInfo(
        tokens=phones, frame_durations=frame_durations, start_sec=start_sec,
        end_sec=end_sec
    )


def get_mfa_alignment(
        textgrid_zip_path: str, sample_ids: List[str], sample_rate: int,
        hop_length: int
) -> Dict[str, ForceAlignmentInfo]:
    return {
        i: get_mfa_alignment_by_sample_id(
            textgrid_zip_path, i, sample_rate, hop_length
        ) for i in tqdm(sample_ids)
    }


def get_unit_alignment(
        id_to_unit_tsv_path: str, sample_ids: List[str]
) -> Dict[str, ForceAlignmentInfo]:
    id_to_units = {
        e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path)
    }
    id_to_units = {i: id_to_units[i].split() for i in sample_ids}
    id_to_units_collapsed = {
        i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items()
    }
    id_to_durations = {
        i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items()
    }

    return {
        i: ForceAlignmentInfo(
            tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i],
            start_sec=None, end_sec=None
        )
        for i in sample_ids
    }


def get_feature_value_min_max(feature_paths: List[str]):
    v_min, v_max = 1e-8, -1e-8
    for p in tqdm(feature_paths):
        _path, slice_ptr = parse_path(p)
        assert len(slice_ptr) == 2
        byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
        assert is_npy_data(byte_data)
        path_or_fp = io.BytesIO(byte_data)
        features = np.load(path_or_fp).squeeze()
        v_min = min(v_min, features.min().item())
        v_max = max(v_max, features.max().item())
    return v_min, v_max
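`trim_or_pad_to_target_length` is the building block the feature extractors above use to force features onto a fixed number of frames. The snippet below is a self-contained sketch of the same trim/zero-pad behaviour on a dummy spectrogram; the array sizes are illustrative only.

```python
# Standalone sketch of the trim/pad behaviour of trim_or_pad_to_target_length:
# 1D or 2D features are truncated when too long and zero-padded when too short,
# always along the first (time) axis.
import numpy as np


def trim_or_pad(x: np.ndarray, target_length: int) -> np.ndarray:
    delta = x.shape[0] - target_length
    if delta >= 0:
        return x[:target_length]
    pad_shape = (-delta,) if x.ndim == 1 else (-delta, x.shape[1])
    return np.concatenate([x, np.zeros(pad_shape)], axis=0)


feats = np.ones((100, 80))            # e.g. a 100-frame, 80-dim log-Mel spectrogram
print(trim_or_pad(feats, 96).shape)   # (96, 80)  -- trimmed
print(trim_or_pad(feats, 120).shape)  # (120, 80) -- zero-padded
```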
fairseq/examples/speech_synthesis/docs/common_voice_example.md
ADDED
@@ -0,0 +1,67 @@
[[Back]](..)

# Common Voice

[Common Voice](https://commonvoice.mozilla.org/en/datasets) is a public domain speech corpus with 11.2K hours of read
speech in 76 languages (the latest version 7.0). We provide examples for building
[Transformer](https://arxiv.org/abs/1809.08895) models on this dataset.


## Data preparation
[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path `${DATA_ROOT}/${LANG_ID}`.
Create splits and generate audio manifests with
```bash
python -m examples.speech_synthesis.preprocessing.get_common_voice_audio_manifest \
  --data-root ${DATA_ROOT} \
  --lang ${LANG_ID} \
  --output-manifest-root ${AUDIO_MANIFEST_ROOT} --convert-to-wav
```

To denoise audio and trim leading/trailing silence using signal-processing-based VAD, run
```bash
for SPLIT in dev test train; do
  python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \
    --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
    --output-dir ${PROCESSED_DATA_ROOT} \
    --denoise --vad --vad-agg-level 2
done
```

which generates a new audio TSV manifest under `${PROCESSED_DATA_ROOT}` with updated paths to the processed audio and
a new column for SNR.

To filter by CER, follow the [Automatic Evaluation](../docs/ljspeech_example.md#automatic-evaluation) section to
run the ASR model (add `--eval-target` to `get_eval_manifest` for evaluation on the reference audio; add `--err-unit char`
to `eval_asr` to compute CER instead of WER). The example-level CER is saved to
`${EVAL_OUTPUT_ROOT}/uer_cer.${SPLIT}.tsv`.

Then, extract log-Mel spectrograms, generate the feature manifest and create the data configuration YAML with
```bash
python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
  --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
  --output-root ${FEATURE_MANIFEST_ROOT} \
  --ipa-vocab --lang ${LANG_ID} \
  --snr-threshold 15 \
  --cer-threshold 0.1 --cer-tsv-path ${EVAL_OUTPUT_ROOT}/uer_cer.${SPLIT}.tsv
```
where we use phoneme inputs (`--ipa-vocab`) as an example. For sample filtering, we set the SNR and CER thresholds
to 15 and 10%, respectively.


## Training
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).)


## Inference
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).)

## Automatic Evaluation
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).)

## Results

| Language | Speakers | --arch | Params | Test MCD | Model |
|---|---|---|---|---|---|
| English | 200 | tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/cv4_en200_transformer_phn.tar) |

[[Back]](..)
fairseq/examples/speech_synthesis/docs/ljspeech_example.md
ADDED
@@ -0,0 +1,137 @@
[[Back]](..)

# LJSpeech

[LJSpeech](https://keithito.com/LJ-Speech-Dataset) is a public domain TTS
corpus with around 24 hours of English speech sampled at 22.05kHz. We provide examples for building
[Transformer](https://arxiv.org/abs/1809.08895) and [FastSpeech 2](https://arxiv.org/abs/2006.04558)
models on this dataset.


## Data preparation

Download data, create splits and generate audio manifests with
```bash
python -m examples.speech_synthesis.preprocessing.get_ljspeech_audio_manifest \
  --output-data-root ${AUDIO_DATA_ROOT} \
  --output-manifest-root ${AUDIO_MANIFEST_ROOT}
```

Then, extract log-Mel spectrograms, generate the feature manifest and create the data configuration YAML with
```bash
python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
  --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
  --output-root ${FEATURE_MANIFEST_ROOT} \
  --ipa-vocab --use-g2p
```
where we use phoneme inputs (`--ipa-vocab --use-g2p`) as an example.

FastSpeech 2 additionally requires frame durations, pitch and energy as auxiliary training targets.
Add `--add-fastspeech-targets` to include these fields in the feature manifests. We get frame durations either from
phoneme-level force-alignment or a frame-level pseudo-text unit sequence. They should be pre-computed and specified via:
- `--textgrid-zip ${TEXT_GRID_ZIP_PATH}` for a ZIP file, inside which there is one
[TextGrid](https://www.fon.hum.uva.nl/praat/manual/TextGrid.html) file per sample to provide force-alignment info.
- `--id-to-units-tsv ${ID_TO_UNIT_TSV}` for a TSV file, where there are 2 columns for sample ID and
space-delimited pseudo-text unit sequence, respectively.

For your convenience, we provide pre-computed
[force-alignment](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_mfa.zip) from
[Montreal Forced Aligner](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) and
[pseudo-text units](s3://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_hubert.tsv) from
[HuBERT](https://github.com/pytorch/fairseq/tree/main/examples/hubert). You can also generate them yourself using
different software or models.


## Training
#### Transformer
```bash
fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
  --config-yaml config.yaml --train-subset train --valid-subset dev \
  --num-workers 4 --max-tokens 30000 --max-update 200000 \
  --task text_to_speech --criterion tacotron2 --arch tts_transformer \
  --clip-norm 5.0 --n-frames-per-step 4 --bce-pos-weight 5.0 \
  --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \
  --encoder-normalize-before --decoder-normalize-before \
  --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
  --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
```
where `SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to
update it accordingly when using more than 1 GPU.

#### FastSpeech2
```bash
fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
  --config-yaml config.yaml --train-subset train --valid-subset dev \
  --num-workers 4 --max-sentences 6 --max-update 200000 \
  --task text_to_speech --criterion fastspeech2 --arch fastspeech2 \
  --clip-norm 5.0 --n-frames-per-step 1 \
  --dropout 0.1 --attention-dropout 0.1 \
  --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
  --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
```


## Inference
Average the last 5 checkpoints, then generate the test split spectrograms and waveforms using the default Griffin-Lim vocoder:
```bash
SPLIT=test
CHECKPOINT_NAME=avg_last_5
CHECKPOINT_PATH=${SAVE_DIR}/checkpoint_${CHECKPOINT_NAME}.pt
python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \
  --num-epoch-checkpoints 5 \
  --output ${CHECKPOINT_PATH}

python -m examples.speech_synthesis.generate_waveform ${FEATURE_MANIFEST_ROOT} \
  --config-yaml config.yaml --gen-subset ${SPLIT} --task text_to_speech \
  --path ${CHECKPOINT_PATH} --max-tokens 50000 --spec-bwd-max-iter 32 \
  --dump-waveforms
```
which dumps files (waveform, feature, attention plot, etc.) to `${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT}`. To
re-synthesize target waveforms for automatic evaluation, add `--dump-target`.

## Automatic Evaluation
To start with, generate the manifest for synthetic speech, which will be taken as input by the evaluation scripts.
```bash
python -m examples.speech_synthesis.evaluation.get_eval_manifest \
  --generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \
  --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
  --output-path ${EVAL_OUTPUT_ROOT}/eval.tsv \
  --vocoder griffin_lim --sample-rate 22050 --audio-format flac \
  --use-resynthesized-target
```
Speech recognition (ASR) models usually operate at lower sample rates (e.g. 16kHz). For the WER/CER metric,
you may need to resample the audios accordingly --- add `--output-sample-rate 16000` for `generate_waveform.py` and
use `--sample-rate 16000` for `get_eval_manifest.py`.


#### WER/CER metric
We use a wav2vec 2.0 ASR model as an example. [Download](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec)
the model checkpoint and dictionary, then compute WER/CER with
```bash
python -m examples.speech_synthesis.evaluation.eval_asr \
  --audio-header syn --text-header text --err-unit char --split ${SPLIT} \
  --w2v-ckpt ${WAV2VEC2_CHECKPOINT_PATH} --w2v-dict-dir ${WAV2VEC2_DICT_DIR} \
  --raw-manifest ${EVAL_OUTPUT_ROOT}/eval_16khz.tsv --asr-dir ${EVAL_OUTPUT_ROOT}/asr
```

#### MCD/MSD metric
```bash
python -m examples.speech_synthesis.evaluation.eval_sp \
  ${EVAL_OUTPUT_ROOT}/eval.tsv --mcd --msd
```

#### F0 metrics
```bash
python -m examples.speech_synthesis.evaluation.eval_f0 \
  ${EVAL_OUTPUT_ROOT}/eval.tsv --gpe --vde --ffe
```


## Results

| --arch | Params | Test MCD | Model |
|---|---|---|---|
| tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_transformer_phn.tar) |
| fastspeech2 | 41M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_fastspeech2_phn.tar) |

[[Back]](..)
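For readers wiring up `--id-to-units-tsv` as FastSpeech 2 duration targets (see the LJSpeech data preparation notes above), the sketch below illustrates how frame durations can be derived from a frame-level pseudo-text unit sequence by collapsing consecutive repeats, mirroring `get_unit_alignment` in `examples/speech_synthesis/data_utils.py`; the unit values are made up for illustration.

```python
# Hedged sketch (not part of the fairseq scripts): frame-level units -> (tokens, durations).
from itertools import groupby

units = "12 12 12 7 7 31 31 31 31".split()

tokens = [u for u, _ in groupby(units)]                 # ['12', '7', '31']
durations = [len(list(g)) for _, g in groupby(units)]   # [3, 2, 4]

# Durations always sum back to the number of frames in the unit sequence
assert sum(durations) == len(units)
print(tokens, durations)
```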