File size: 3,235 Bytes
affcd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import List, NamedTuple

import torch
from pyctcdecode import build_ctcdecoder


from hw_asr.base.base_text_encoder import BaseTextEncoder
from .char_text_encoder import CharTextEncoder
from collections import defaultdict


class Hypothesis(NamedTuple):
    """One beam-search candidate: a decoded string together with its score."""

    # Decoded transcription of this candidate.
    text: str
    # Score associated with ``text`` (the value produced by the search).
    prob: float


class CTCCharTextEncoder(CharTextEncoder):
    """Character text encoder with CTC blank handling and CTC decoders.

    The vocabulary is ``[EMPTY_TOK] + alphabet``: index 0 is the CTC blank and
    the alphabet characters follow in order. Decoding strategies:

    * :meth:`ctc_decode` - greedy collapse of per-frame token indices;
    * :meth:`ctc_beam_search` - prefix beam search without a language model;
    * :meth:`ctc_lm_beam_search` - KenLM-fused beam search through
      ``pyctcdecode`` (available only when a KenLM model path was given).
    """

    # Symbol used for the CTC blank token; it occupies index 0 of the vocab.
    EMPTY_TOK = "^"

    def __init__(self, alphabet: List[str] = None, kenlm_model_path: str = None, unigrams_path: str = None):
        """Build the index/char maps and, optionally, a KenLM beam-search decoder.

        :param alphabet: vocabulary characters (without the blank); forwarded
            to ``CharTextEncoder.__init__``.
        :param kenlm_model_path: path to a KenLM model. When ``None`` no LM
            decoder is built and :meth:`ctc_lm_beam_search` raises.
        :param unigrams_path: optional text file with one unigram per line,
            passed to ``build_ctcdecoder``; blank lines are skipped. Only read
            when ``kenlm_model_path`` is given.
        """
        super().__init__(alphabet)
        vocab = [self.EMPTY_TOK] + list(self.alphabet)
        self.ind2char = dict(enumerate(vocab))
        self.char2ind = {char: ind for ind, char in self.ind2char.items()}
        # Always define the attribute so ctc_lm_beam_search can fail with a
        # clear message instead of an AttributeError.
        self.decoder = None
        if kenlm_model_path is not None:
            unigrams = None
            if unigrams_path is not None:
                with open(unigrams_path, encoding="utf-8") as f:
                    words = (line.strip() for line in f)
                    unigrams = [word for word in words if word]
            # pyctcdecode expects the blank label as "" (per its docs); it sits
            # at index 0, matching EMPTY_TOK's position in ind2char.
            self.decoder = build_ctcdecoder(
                labels=[""] + list(self.alphabet),
                kenlm_model_path=kenlm_model_path,
                unigrams=unigrams,
            )

    def ctc_decode(self, inds: List[int]) -> str:
        """Greedy CTC decoding: merge consecutive repeats, then drop blanks.

        A blank between two identical characters keeps them separate, e.g. the
        index sequence for ``a a ^ a`` decodes to ``"aa"``.

        :param inds: per-frame token indices. Each element is passed through
            ``int()`` so numpy integers and 0-d tensors are accepted (tensors
            hash by identity and would otherwise miss the dict lookup).
        :return: the decoded text.
        """
        result = []
        last_char = self.EMPTY_TOK
        for ind in inds:
            cur_char = self.ind2char[int(ind)]
            if cur_char != self.EMPTY_TOK and cur_char != last_char:
                result.append(cur_char)
            last_char = cur_char
        return "".join(result)

    def ctc_beam_search(self, probs: torch.tensor, beam_size: int) -> str:
        """Prefix beam search over per-frame CTC scores, without a language model.

        Hypotheses are keyed by ``(prefix, last_token)`` so that different
        alignments collapsing to the same prefix are summed together. After
        each frame only the ``beam_size`` best hypotheses are kept.

        :param probs: 2-D array of shape ``(frames, vocab_size)``. Scores are
            multiplied, so these are presumably probabilities (e.g. softmax
            output), not log-probs - TODO confirm at call sites.
        :param beam_size: number of hypotheses kept after each frame (>= 1).
        :return: text of the highest-scoring hypothesis ("" for zero frames).
        :raises ValueError: if ``beam_size < 1``.
        """
        assert len(probs.shape) == 2
        _, voc_size = probs.shape
        assert voc_size == len(self.ind2char)
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")

        blank = self.EMPTY_TOK
        # Decode token indices once instead of inside the inner loop.
        chars = [self.ind2char[i] for i in range(voc_size)]
        # Plain Python floats: avoids allocating a 0-d tensor per product.
        frames = probs.tolist()

        def extend_and_merge(frame, state):
            # Extend every hypothesis by every token of this frame, summing
            # the scores of paths that collapse to the same key.
            new_state = defaultdict(float)
            for next_char, next_prob in zip(chars, frame):
                for (prefix, last_char), prefix_prob in state.items():
                    if next_char == last_char:
                        # Repeat (or blank after blank) collapses into the prefix.
                        key = (prefix, last_char)
                    elif next_char == blank:
                        key = (prefix, blank)
                    else:
                        key = (prefix + next_char, next_char)
                    new_state[key] += prefix_prob * next_prob
            return new_state

        def truncate(state):
            # Keep the best beam_size hypotheses (stable sort: ties keep
            # insertion order) and rescale them to sum to 1 so long inputs do
            # not underflow; a common scale factor does not change the ranking.
            best = sorted(state.items(), key=lambda kv: kv[1], reverse=True)[:beam_size]
            total = sum(score for _, score in best)
            if total > 0:
                return {key: score / total for key, score in best}
            return dict(best)

        state = {("", blank): 1.0}
        for frame in frames:
            state = truncate(extend_and_merge(frame, state))

        (best_prefix, _), _ = max(state.items(), key=lambda kv: kv[1])
        return best_prefix

    def ctc_lm_beam_search(self, logits: torch.tensor, beam_width: int = 500) -> str:
        """Beam search fused with the KenLM model via ``pyctcdecode``.

        :param logits: per-frame scores of shape ``(frames, vocab_size)``;
            torch tensors are converted to numpy before being handed to
            ``decoder.decode``. pyctcdecode presumably expects log-probabilities
            here - TODO confirm what callers pass.
        :param beam_width: beam width forwarded to ``decoder.decode``.
        :return: decoded text, lower-cased.
        :raises RuntimeError: if no KenLM model was configured in ``__init__``.
        """
        decoder = getattr(self, "decoder", None)
        if decoder is None:
            raise RuntimeError(
                "LM beam search requires kenlm_model_path to be set in CTCCharTextEncoder.__init__"
            )
        if isinstance(logits, torch.Tensor):
            logits = logits.detach().cpu().numpy()
        return decoder.decode(logits, beam_width=beam_width).lower()