File size: 4,774 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import numpy as np
import torch


class SegmentStreamingE2E(object):
    """Segment-wise streaming recognizer built on top of an E2E ASR model.

    Input is fed incrementally through :meth:`accept_input`. The greedy CTC
    output of the encoder is used to detect where a segment starts (first
    non-blank label) and where it ends (a run of blank labels); the decoder's
    beam search is then run once on each detected segment.

    :param E2E e2e: E2E ASR object
    :param recog_args: arguments for "recognize" method of E2E
    :param rnnlm: object handed unchanged to the decoder's beam search
        (presumably an RNN language model), or ``None``
    """

    def __init__(self, e2e, recog_args, rnnlm=None):
        self._e2e = e2e
        self._recog_args = recog_args
        self._char_list = e2e.char_list
        self._rnnlm = rnnlm

        self._e2e.eval()

        # Index of the blank symbol in the character list; -1 when absent.
        self._blank_idx_in_char_list = next(
            (
                idx
                for idx, char in enumerate(self._char_list)
                if char == self._e2e.blank
            ),
            -1,
        )

        # Product of the encoder's subsampling factors. Cast to a plain int
        # because it is used to compute slice bounds below.
        self._subsampling_factor = int(np.prod(e2e.subsample))
        self._activates = 0  # 1 while inside a detected segment, else 0
        self._blank_dur = 0  # consecutive blank detections in the segment

        self._previous_input = []  # raw input frames kept as onset context
        self._previous_encoder_recurrent_state = None
        self._encoder_states = []  # encoder frames of the current segment
        self._ctc_posteriors = []  # CTC log-posteriors of the current segment

        # NOTE: kept as asserts so the raised exception type is unchanged;
        # be aware that they are skipped when Python runs with -O.
        assert (
            self._recog_args.batchsize <= 1
        ), "SegmentStreamingE2E works only with batch size <= 1"
        assert (
            "b" not in self._e2e.etype
        ), "SegmentStreamingE2E works only with uni-directional encoders"

    def accept_input(self, x):
        """Call this method each time a new batch of input is available.

        :param x: new chunk of input frames. Its rows are appended to an
            internal buffer and the chunk is passed to
            ``e2e.subsample_frames`` (presumably a 2-D array of shape
            (frames, feature_dim) -- confirm against the caller).
        :return: the hypothesis returned by the decoder's beam search when
            this call completes a segment, otherwise ``None``
        """
        self._previous_input.extend(x)
        h, ilen = self._e2e.subsample_frames(x)

        # Run encoder and apply greedy search on CTC softmax output
        h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
            h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
        )
        z = self._e2e.ctc.argmax(h).squeeze(0)
        # Only the first label of the chunk drives onset/offset detection.
        is_blank = bool(z[0] == self._blank_idx_in_char_list)

        onset_context = self._subsampling_factor * (
            self._recog_args.streaming_onset_margin + 1
        )

        if self._activates == 0 and not is_blank:
            self._activates = 1

            # Rerun encoder with zero state at onset of detection
            feat_dim = len(self._previous_input[0])
            onset_frames = self._previous_input[-onset_context:]
            h, ilen = self._e2e.subsample_frames(
                np.reshape(onset_frames, [-1, feat_dim])
            )
            h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
                h.unsqueeze(0), ilen, None
            )

        hyp = None
        if self._activates == 1:
            self._encoder_states.extend(h.squeeze(0))
            self._ctc_posteriors.extend(self._e2e.ctc.log_softmax(h).squeeze(0))

            if is_blank:
                self._blank_dur += 1
            else:
                self._blank_dur = 0

            if self._blank_dur >= self._recog_args.streaming_min_blank_dur:
                # Drop the trailing blanks, then add back the offset margin.
                seg_len = (
                    len(self._encoder_states)
                    - self._blank_dur
                    + self._recog_args.streaming_offset_margin
                )
                if seg_len > 0:
                    hyp = self._decode_segment(seg_len)

                    self._activates = 0
                    self._blank_dur = 0

                    tail_len = (
                        self._subsampling_factor
                        * self._recog_args.streaming_onset_margin
                    )
                    # ``seq[-0:]`` is the whole sequence, so a zero margin
                    # must be mapped to an explicitly empty buffer.
                    self._previous_input = (
                        self._previous_input[-tail_len:] if tail_len > 0 else []
                    )
                    self._encoder_states = []
                    self._ctc_posteriors = []

        # While idle, only the frames that the next onset rerun can read are
        # worth keeping; without this the buffer grows on every call.
        if self._activates == 0 and 0 < onset_context < len(self._previous_input):
            self._previous_input = self._previous_input[-onset_context:]

        return hyp

    def _decode_segment(self, seg_len):
        """Run the decoder's beam search on the first ``seg_len`` frames.

        :param int seg_len: number of buffered encoder frames to decode
        :return: result of ``recognize_beam`` when ``batchsize`` is 0,
            otherwise the first element of ``recognize_beam_batch``'s result
        """
        h = torch.cat(self._encoder_states[:seg_len], dim=0).view(
            -1, self._encoder_states[0].size(0)
        )
        if self._recog_args.ctc_weight > 0.0:
            lpz = torch.cat(self._ctc_posteriors[:seg_len], dim=0).view(
                -1, self._ctc_posteriors[0].size(0)
            )
            if self._recog_args.batchsize > 0:
                # The batched beam search receives a leading batch axis.
                lpz = lpz.unsqueeze(0)
            normalize_score = False
        else:
            lpz = None
            normalize_score = True

        if self._recog_args.batchsize == 0:
            return self._e2e.dec.recognize_beam(
                h, lpz, self._recog_args, self._char_list, self._rnnlm
            )

        hlens = torch.tensor([h.shape[0]])
        return self._e2e.dec.recognize_beam_batch(
            h.unsqueeze(0),
            hlens,
            lpz,
            self._recog_args,
            self._char_list,
            self._rnnlm,
            normalize_score=normalize_score,
        )[0]