File size: 6,862 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
"""
The beam search decoder of flashlight
Authors:
* Heng-Jui Chang 2022
"""
import itertools as it
import logging
import math
from typing import Iterable, List
import torch
from s3prl.util.download import _urls_to_filepaths
TOKEN_URL = "https://huggingface.co/datasets/s3prl/flashlight/raw/main/lexicon/librispeech_char_tokens.txt"
LEXICON_URL_1 = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/librispeech_lexicon.lst"
LEXICON_URL_2 = "https://huggingface.co/datasets/s3prl/flashlight/raw/main/lexicon/librispeech_lexicon.lst"
LM_URL_1 = "https://www.openslr.org/resources/11/4-gram.arpa.gz"
LM_URL_2 = (
"https://huggingface.co/datasets/s3prl/flashlight/resolve/main/lm/4-gram.arpa.gz"
)
logger = logging.getLogger(__name__)
__all__ = ["BeamDecoder"]
class BeamDecoder(object):
    """Beam-search decoder powered by Flashlight Text.

    Runs a lexicon-constrained CTC beam search over emission log-probabilities,
    rescored with a KenLM n-gram language model.

    Args:
        token (str, optional): Path to dictionary file. Downloaded if "". Defaults to "".
        lexicon (str, optional): Path to lexicon file. Downloaded if "". Defaults to "".
        lm (str, optional): Path to KenLM file. Downloaded if "". Defaults to "".
        nbest (int, optional): Returns nbest hypotheses. Defaults to 1.
        beam (int, optional): Beam size. Defaults to 5.
        beam_size_token (int, optional): Token beam size (-1 means all tokens). Defaults to -1.
        beam_threshold (float, optional): Beam search log prob threshold. Defaults to 25.0.
        lm_weight (float, optional): Language model weight. Defaults to 2.0.
        word_score (float, optional): Score for words appearance in the transcription. Defaults to -1.0.
        unk_score (float, optional): Score for unknown word appearance in the transcription. Defaults to -math.inf.
        sil_score (float, optional): Score for silence appearance in the transcription. Defaults to 0.0.
    """

    def __init__(
        self,
        token: str = "",
        lexicon: str = "",
        lm: str = "",
        nbest: int = 1,
        beam: int = 5,
        beam_size_token: int = -1,
        beam_threshold: float = 25.0,
        lm_weight: float = 2.0,
        word_score: float = -1.0,
        unk_score: float = -math.inf,
        sil_score: float = 0.0,
    ):
        try:
            from flashlight.lib.text.decoder import (
                CriterionType,
                KenLM,
                LexiconDecoder,
                LexiconDecoderOptions,
                SmearingMode,
                Trie,
            )
            from flashlight.lib.text.dictionary import (
                Dictionary,
                create_word_dict,
                load_words,
            )
        except ImportError:
            logger.error(
                f"Please install Flashlight Text from https://github.com/flashlight/text to enable {__class__.__name__}"
            )
            raise

        # Download default resources when paths are not supplied.  The lexicon
        # and LM each have a mirror URL that is tried when the primary fails.
        if token == "":
            token = _urls_to_filepaths(TOKEN_URL)
        if lexicon == "":
            try:
                lexicon = _urls_to_filepaths(LEXICON_URL_1)
            except Exception:
                logger.warning(
                    "Failed to fetch lexicon from %s; falling back to %s",
                    LEXICON_URL_1,
                    LEXICON_URL_2,
                )
                lexicon = _urls_to_filepaths(LEXICON_URL_2)
        if lm == "":
            try:
                lm = _urls_to_filepaths(LM_URL_1)
            except Exception:
                logger.warning(
                    "Failed to fetch LM from %s; falling back to %s",
                    LM_URL_1,
                    LM_URL_2,
                )
                lm = _urls_to_filepaths(LM_URL_2)

        self.nbest = nbest
        self.token_dict = Dictionary(token)
        self.lexicon = load_words(lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.lm = KenLM(lm, self.word_dict)
        # "|" is the word-boundary (silence) token in the char token set.
        self.sil_idx = self.token_dict.get_index("|")
        self.unk_idx = self.word_dict.get_index("<unk>")

        # Build a trie over all lexicon spellings; each leaf is scored with
        # the LM's unigram score so partial hypotheses can be pruned fairly.
        self.trie = Trie(self.token_dict.index_size(), self.sil_idx)
        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            usr_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, usr_idx)
            for spelling in spellings:
                spelling_idxs = [self.token_dict.get_index(tok) for tok in spelling]
                self.trie.insert(spelling_idxs, usr_idx, score)
        # MAX smearing propagates the best child score up the trie.
        self.trie.smear(SmearingMode.MAX)

        if beam_size_token == -1:
            beam_size_token = self.token_dict.index_size()

        self.options = LexiconDecoderOptions(
            beam_size=beam,
            beam_size_token=beam_size_token,
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            word_score=word_score,
            unk_score=unk_score,
            sil_score=sil_score,
            log_add=False,
            criterion_type=CriterionType.CTC,
        )
        # "#" is the CTC blank token in this dictionary.
        self.blank_idx = self.token_dict.get_index("#")
        self.decoder = LexiconDecoder(
            self.options,
            self.trie,
            self.lm,
            self.sil_idx,
            self.blank_idx,
            self.unk_idx,
            [],
            False,
        )

    def get_tokens(self, idxs: Iterable) -> torch.LongTensor:
        """Normalize tokens by handling CTC blank, ASG replabels, etc.

        Args:
            idxs (Iterable): Token ID list output by self.decoder

        Returns:
            torch.LongTensor: Token ID list after normalization.
        """
        # Collapse consecutive repeats, then drop CTC blanks.
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank_idx, idxs)
        return torch.LongTensor(list(idxs))

    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
        """Returns frame numbers corresponding to every non-blank token.

        Args:
            token_idxs (List[int]): IDs of decoded tokens.

        Returns:
            List[int]: Frame numbers corresponding to every non-blank token.
        """
        timesteps = []
        for i, token_idx in enumerate(token_idxs):
            if token_idx == self.blank_idx:
                continue
            # Keep only the first frame of each run of identical tokens.
            if i == 0 or token_idx != token_idxs[i - 1]:
                timesteps.append(i)
        return timesteps

    def decode(self, emissions: torch.Tensor) -> List[List[dict]]:
        """Decode sequence.

        Args:
            emissions (torch.Tensor): Emission probabilities (in log scale),
                shape (batch, time, num_tokens).

        Returns:
            List[List[dict]]: Decoded hypotheses, one list of nbest dicts per
                batch element with keys "tokens", "score", "timesteps", "words".
        """
        emissions = emissions.float().contiguous().cpu()
        B, T, N = emissions.size()
        hyps = []
        for b in range(B):
            # Flashlight consumes a raw pointer; offset by b rows.  Use the
            # tensor's element size instead of a hard-coded 4 so the pointer
            # arithmetic stays correct for the actual dtype.
            emissions_ptr = (
                emissions.data_ptr()
                + emissions.element_size() * b * emissions.stride(0)
            )
            results = self.decoder.decode(emissions_ptr, T, N)
            nbest_results = results[: self.nbest]
            hyps.append(
                [
                    dict(
                        tokens=self.get_tokens(result.tokens),
                        score=result.score,
                        timesteps=self.get_timesteps(result.tokens),
                        words=[
                            self.word_dict.get_entry(x) for x in result.words if x >= 0
                        ],
                    )
                    for result in nbest_results
                ]
            )
        return hyps
|