Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- encoding: utf-8 -*- | |
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. | |
# MIT License (https://opensource.org/licenses/MIT) | |
import torch | |
from typing import List, Optional, Tuple | |
from funasr_detach.register import tables | |
from funasr_detach.models.specaug.specaug import SpecAug | |
from funasr_detach.models.transducer.beam_search_transducer import Hypothesis | |
class RNNTDecoder(torch.nn.Module):
    """RNN decoder module.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type, either "lstm" or "gru".
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.
        use_embed_mask: If True, apply SpecAug time masking to the label
            embeddings while the module is in training mode.

    Raises:
        ValueError: If ``rnn_type`` is neither "lstm" nor "gru".

    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct a RNNTDecoder object."""
        super().__init__()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

        # Each decoder layer is a separate single-layer RNN: rnn_forward()
        # applies dropout after every layer and tracks per-layer states.
        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
            + [
                rnn_class(hidden_size, hidden_size, 1, batch_first=True)
                for _ in range(1, num_layers)
            ]
        )
        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type

        self.output_size = hidden_size
        self.vocab_size = vocab_size

        # Default device for init_state()/batch_score(); captured once here,
        # so call set_device() after moving the module to another device.
        self.device = next(self.parameters()).device
        # Decoder outputs/states keyed by the joined label sequence (see score()).
        self.score_cache = {}

        self.use_embed_mask = use_embed_mask
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=4,
                apply_freq_mask=False,
                apply_time_warp=False,
            )

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label sequence lengths; only passed to the embedding
                mask when ``use_embed_mask`` is enabled in training mode.
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, L, D_dec)

        """
        if states is None:
            # Allocate on the labels' device so this works even if
            # set_device() was not called after moving the module.
            states = self.init_state(labels.size(0), device=labels.device)

        dec_embed = self.dropout_embed(self.embed(labels))
        if self.use_embed_mask and self.training:
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]

        dec_out, _ = self.rnn_forward(dec_embed, states)

        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Run the stacked RNN layers.

        Args:
            x: RNN input sequences. (B, U, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, U, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None)

        """
        if self.dlayers == 0:
            # No RNN layers: input passes through, states are empty (0, B, D).
            return x, self.init_state(x.size(0), device=x.device)

        h_prev, c_prev = state
        h_layers, c_layers = [], []

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                x, (h_layer, c_layer) = self.rnn[layer](
                    x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])
                )
                c_layers.append(c_layer)
            else:
                x, h_layer = self.rnn[layer](x, hx=h_prev[layer : layer + 1])
            h_layers.append(h_layer)

            x = self.dropout_rnn[layer](x)

        # Concatenating the per-layer states keeps them on the RNN's own
        # device/dtype instead of copying into freshly allocated buffers.
        h_next = torch.cat(h_layers, dim=0)
        c_next = torch.cat(c_layers, dim=0) if self.dtype == "lstm" else None

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Results are memoized in ``self.score_cache`` keyed by the label
        sequence only, so ``dec_state`` is ignored on a cache hit. The cache
        is never cleared in this class; presumably the caller resets it
        between utterances — TODO confirm.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state

    def batch_score(
        self,
        hyps: "List[Hypothesis]",
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses; reads ``h.yseq[-1]`` and ``h.dec_state`` of each.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        # torch.LongTensor(data, device=...) rejects non-CPU devices;
        # torch.tensor honours the requested device.
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set the default device used by init_state() and batch_score().

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self, batch_size: int, device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states with zeros.

        Args:
            batch_size: Batch size.
            device: Device for the states; defaults to ``self.device``.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        if device is None:
            device = self.device

        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=device,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Create batched decoder hidden states.

        Args:
            new_states: Per-hypothesis decoder hidden states, concatenated
                along dim 1. [B x ((N, 1, D_dec), (N, 1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            (
                torch.cat([s[1] for s in new_states], dim=1)
                if self.dtype == "lstm"
                else None
            ),
        )