File size: 3,235 Bytes
affcd23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from typing import List, NamedTuple
import torch
from pyctcdecode import build_ctcdecoder
from hw_asr.base.base_text_encoder import BaseTextEncoder
from .char_text_encoder import CharTextEncoder
from collections import defaultdict
# A single decoding hypothesis: the candidate transcript and its probability.
Hypothesis = NamedTuple("Hypothesis", [("text", str), ("prob", float)])
class CTCCharTextEncoder(CharTextEncoder):
    """Char-level text encoder with CTC decoding.

    Provides greedy decoding (`ctc_decode`), prefix beam search
    (`ctc_beam_search`) and, when a KenLM model is supplied,
    LM-rescored beam search via pyctcdecode (`ctc_lm_beam_search`).
    """

    # CTC blank token; index 0 in the vocabulary.
    EMPTY_TOK = "^"

    def __init__(self, alphabet: List[str] = None, kenlm_model_path: str = None, unigrams_path: str = None):
        """
        :param alphabet: characters of the alphabet (blank is prepended automatically)
        :param kenlm_model_path: optional path to a KenLM model; enables ctc_lm_beam_search
        :param unigrams_path: optional path to a unigram list (one word per line) for pyctcdecode
        """
        super().__init__(alphabet)
        vocab = [self.EMPTY_TOK] + list(self.alphabet)
        self.ind2char = dict(enumerate(vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        # Always define the attribute so ctc_lm_beam_search can check it
        # instead of raising AttributeError when no LM was configured.
        self.decoder = None
        if kenlm_model_path is not None:
            # unigrams are optional for pyctcdecode; don't crash when no path is given
            unigrams = None
            if unigrams_path is not None:
                with open(unigrams_path) as f:
                    unigrams = [line.strip() for line in f]
            self.decoder = build_ctcdecoder(
                labels=[""] + list(self.alphabet),
                kenlm_model_path=kenlm_model_path,
                unigrams=unigrams,
            )

    def ctc_decode(self, inds: List[int]) -> str:
        """Greedy CTC decode: collapse consecutive repeats, then drop blanks.

        :param inds: sequence of per-frame argmax character indices
        :return: decoded text
        """
        result = []
        last_char = self.EMPTY_TOK
        for ind in inds:
            cur_char = self.ind2char[ind]
            # emit a char only when it is not blank and differs from the previous frame
            if cur_char != self.EMPTY_TOK and last_char != cur_char:
                result.append(cur_char)
            last_char = cur_char
        return ''.join(result)

    def ctc_beam_search(self, probs: torch.Tensor, beam_size: int) -> str:
        """Prefix beam search over per-frame character probabilities.

        :param probs: tensor of shape (time, vocab) with per-frame char probabilities
        :param beam_size: number of (prefix, last_char) states kept per frame
        :return: text of the most probable hypothesis
        """
        assert len(probs.shape) == 2
        char_length, voc_size = probs.shape
        assert voc_size == len(self.ind2char)

        def extend_and_merge(frame, state):
            # Extend every kept hypothesis by every character of this frame,
            # merging hypotheses that collapse to the same (prefix, last_char).
            new_state = defaultdict(float)
            for next_char_index, next_char_proba in enumerate(frame):
                # depends only on the outer loop variable — hoisted out of the inner loop
                next_char = self.ind2char[next_char_index]
                for (pref, last_char), pref_proba in state.items():
                    if next_char == last_char:
                        # repeated char (or repeated blank): prefix unchanged
                        new_pref = pref
                    else:
                        # blanks never extend the prefix; other chars do
                        new_pref = pref + next_char if next_char != self.EMPTY_TOK else pref
                        last_char = next_char
                    new_state[(new_pref, last_char)] += pref_proba * next_char_proba
            return new_state

        def truncate(state, k):
            # keep the k most probable states
            top = sorted(state.items(), key=lambda kv: -kv[1])[:k]
            return dict(top)

        # start with the empty prefix preceded by a blank
        state = {('', self.EMPTY_TOK): 1.0}
        for frame in probs:
            state = truncate(extend_and_merge(frame, state), beam_size)
        # only the single best hypothesis is needed — no full sort
        best_key, _ = max(state.items(), key=lambda kv: kv[1])
        return best_key[0]

    def ctc_lm_beam_search(self, logits: torch.Tensor) -> str:
        """LM-rescored beam search via pyctcdecode.

        Requires that kenlm_model_path was supplied to __init__.

        :param logits: per-frame logits, shape (time, vocab)
        :return: decoded text, lowercased
        """
        assert self.decoder is not None
        return self.decoder.decode(logits, beam_width=500).lower()