import numpy as np
import torch


class SegmentStreamingE2E(object):
    """Segment-based streaming decoder for E2E ASR.

    Buffers incoming feature frames, detects speech segments with greedy CTC
    search, and runs beam-search decoding on each detected segment.

    :param E2E e2e: E2E ASR object
    :param recog_args: arguments for the "recognize" method of E2E
    :param rnnlm: optional RNN language model used during decoding
    """

    def __init__(self, e2e, recog_args, rnnlm=None):
        self._e2e = e2e
        self._recog_args = recog_args
        self._char_list = e2e.char_list
        self._rnnlm = rnnlm

        self._e2e.eval()

        # Locate the CTC blank symbol in the character list
        self._blank_idx_in_char_list = -1
        for idx in range(len(self._char_list)):
            if self._char_list[idx] == self._e2e.blank:
                self._blank_idx_in_char_list = idx
                break

        self._subsampling_factor = np.prod(e2e.subsample)

        # Streaming state: segment-detection flag, length of the current run
        # of blank labels, and buffers accumulated since the last segment
        self._activates = 0
        self._blank_dur = 0
        self._previous_input = []
        self._previous_encoder_recurrent_state = None
        self._encoder_states = []
        self._ctc_posteriors = []

        assert (
            self._recog_args.batchsize <= 1
        ), "SegmentStreamingE2E works only with batch size <= 1"
        assert (
            "b" not in self._e2e.etype
        ), "SegmentStreamingE2E works only with uni-directional encoders"
    def accept_input(self, x):
        """Call this method each time a new batch of input is available.

        Returns a decoded hypothesis once a segment has been detected and
        decoded, otherwise None.
        """
        self._previous_input.extend(x)
        h, ilen = self._e2e.subsample_frames(x)

        # Run encoder and apply greedy search on CTC softmax output
        h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
            h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
        )
        z = self._e2e.ctc.argmax(h).squeeze(0)

        # Segment onset: first non-blank CTC label while no segment is open
        if self._activates == 0 and z[0] != self._blank_idx_in_char_list:
            self._activates = 1

            # Rerun encoder with zero state at onset of detection
            tail_len = self._subsampling_factor * (
                self._recog_args.streaming_onset_margin + 1
            )
            h, ilen = self._e2e.subsample_frames(
                np.reshape(
                    self._previous_input[-tail_len:],
                    [-1, len(self._previous_input[0])],
                )
            )
            h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
                h.unsqueeze(0), ilen, None
            )

        hyp = None
        if self._activates == 1:
            # Accumulate encoder states and CTC posteriors for the open segment
            self._encoder_states.extend(h.squeeze(0))
            self._ctc_posteriors.extend(self._e2e.ctc.log_softmax(h).squeeze(0))

            # Track the length of the current run of blank labels
            if z[0] == self._blank_idx_in_char_list:
                self._blank_dur += 1
            else:
                self._blank_dur = 0

            # A sufficiently long run of blanks closes the segment
            if self._blank_dur >= self._recog_args.streaming_min_blank_dur:
                seg_len = (
                    len(self._encoder_states)
                    - self._blank_dur
                    + self._recog_args.streaming_offset_margin
                )

                if seg_len > 0:
                    # Run decoder with a detected segment
                    h = torch.cat(self._encoder_states[:seg_len], dim=0).view(
                        -1, self._encoder_states[0].size(0)
                    )
                    if self._recog_args.ctc_weight > 0.0:
                        lpz = torch.cat(self._ctc_posteriors[:seg_len], dim=0).view(
                            -1, self._ctc_posteriors[0].size(0)
                        )
                        if self._recog_args.batchsize > 0:
                            lpz = lpz.unsqueeze(0)
                        normalize_score = False
                    else:
                        lpz = None
                        normalize_score = True

                    if self._recog_args.batchsize == 0:
                        hyp = self._e2e.dec.recognize_beam(
                            h, lpz, self._recog_args, self._char_list, self._rnnlm
                        )
                    else:
                        hlens = torch.tensor([h.shape[0]])
                        hyp = self._e2e.dec.recognize_beam_batch(
                            h.unsqueeze(0),
                            hlens,
                            lpz,
                            self._recog_args,
                            self._char_list,
                            self._rnnlm,
                            normalize_score=normalize_score,
                        )[0]

                    # Reset detection state; keep only the onset-margin tail of
                    # the raw input so the next segment can rerun the encoder
                    self._activates = 0
                    self._blank_dur = 0

                    tail_len = (
                        self._subsampling_factor
                        * self._recog_args.streaming_onset_margin
                    )
                    self._previous_input = self._previous_input[-tail_len:]
                    self._encoder_states = []
                    self._ctc_posteriors = []

        return hyp
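
# Example usage (illustrative sketch only): `load_trained_e2e`, `recog_args`,
# and `feature_chunks` below are placeholders, not part of this module; in
# practice the E2E model, recognition arguments, and feature stream come from
# the recognition script that drives streaming decoding.
#
#     e2e, train_args = load_trained_e2e("model.acc.best")  # hypothetical loader
#     streamer = SegmentStreamingE2E(e2e, recog_args, rnnlm=None)
#     for chunk in feature_chunks:  # iterable of (frames, feat_dim) arrays
#         hyp = streamer.accept_input(chunk)
#         if hyp is not None:
#             print(hyp)  # hypothesis for the segment that just ended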