|
import torch |
|
|
|
|
|
|
|
|
|
class WindowStreamingE2E(object):
    """Feed input to an E2E model window by window, then decode it.

    Each call to :meth:`accept_input` runs the encoder and the CTC
    log-softmax on one window of input and stores both results.
    :meth:`decode_with_attention_offline` later runs the attention decoder
    over everything that was stored.

    :param E2E e2e: E2E ASR object
    :param recog_args: arguments for "recognize" method of E2E
    :param rnnlm: optional language model, forwarded unchanged to
        ``e2e.dec.recognize_beam``
    """

    def __init__(self, e2e, recog_args, rnnlm=None):
        self._e2e = e2e
        self._recog_args = recog_args
        self._char_list = e2e.char_list
        self._rnnlm = rnnlm

        # Switch the model to evaluation mode before any input is processed.
        self._e2e.eval()

        # Number of frames the decoder has already consumed. Nothing in this
        # class advances it yet, so it stays at 0.
        self._offset = 0
        # Recurrent state returned by the encoder for the previous window;
        # passed back in on the next call so the encoder can continue from it.
        self._previous_encoder_recurrent_state = None
        # One tensor per accepted window, in arrival order.
        self._encoder_states = []
        self._ctc_posteriors = []
        self._last_recognition = None

        # NOTE: deliberately kept as an assert so callers keep seeing the same
        # AssertionError; be aware it is skipped when Python runs with -O.
        assert (
            self._recog_args.ctc_weight > 0.0
        ), "WindowStreamingE2E works only with combined CTC and attention decoders."

    def accept_input(self, x):
        """Call this method each time a new batch of input is available.

        The encoder output and the CTC log-posteriors computed for ``x`` are
        appended to the internal per-window lists; nothing is returned.

        :param x: new chunk of input, passed to ``e2e.subsample_frames``
        """
        h, ilen = self._e2e.subsample_frames(x)

        # Streaming encoder: add a leading batch axis, and hand over the
        # recurrent state left behind by the previous window.
        h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
            h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
        )
        self._encoder_states.append(h.squeeze(0))

        # CTC posteriors for the incoming window, computed on the still
        # batched encoder output and then stripped of the batch axis.
        self._ctc_posteriors.append(self._e2e.ctc.log_softmax(h).squeeze(0))

    def _input_window_for_decoder(self, use_all=False):
        """Return the encoder states and CTC posteriors meant for the decoder.

        :param bool use_all: if True, return everything accumulated so far;
            otherwise return only the windows that start at or after
            ``self._offset`` frames
        :return: tuple ``(encoder_states, ctc_posteriors)``; each element is
            the selected windows concatenated along dim 0

        The lists are handed to ``torch.cat`` as they are, so calling this
        before any input was accepted (or when no window is selected) fails
        inside ``torch.cat``.
        """
        if use_all:
            return (
                torch.cat(self._encoder_states, dim=0),
                torch.cat(self._ctc_posteriors, dim=0),
            )

        return (
            self._select_unprocessed_windows(self._encoder_states),
            self._select_unprocessed_windows(self._ctc_posteriors),
        )

    def _select_unprocessed_windows(self, window_tensors):
        """Concatenate the windows that start at or after ``self._offset``.

        Frames are counted along dim 0, the same axis the windows are
        concatenated on. The trailing dimension differs between encoder
        states and CTC posteriors, so counting on it would make the two
        selections disagree.

        :param list window_tensors: per-window tensors in arrival order
        :return: the selected windows concatenated along dim 0
        """
        # NOTE(review): a window that straddles ``self._offset`` is skipped
        # as a whole; confirm this is the wanted policy once an online
        # decoder starts advancing ``self._offset``.
        frames_seen = 0
        selected = []
        for window in window_tensors:
            # ">=" keeps the window that begins exactly at the offset; in
            # particular nothing is dropped while the offset is still 0.
            if frames_seen >= self._offset:
                selected.append(window)
            frames_seen += window.size(0)
        return torch.cat(selected, dim=0)

    def decode_with_attention_offline(self):
        """Run the attention decoder offline.

        Works even if the previous layers (encoder and CTC decoder) were
        being run in the online mode.
        This method should be run after all the audio has been consumed.
        This is used mostly to compare the results between offline
        and online implementation of the previous layers.

        :return: whatever ``e2e.dec.recognize_beam`` returns for the full
            accumulated input
        """
        h, lpz = self._input_window_for_decoder(use_all=True)

        return self._e2e.dec.recognize_beam(
            h, lpz, self._recog_args, self._char_list, self._rnnlm
        )
|
|