|
"""RNN decoder for transducer-based models.""" |
|
|
|
import torch |
|
|
|
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface |
|
|
|
|
|
class DecoderRNNT(TransducerDecoderInterface, torch.nn.Module):
    """RNN-T Decoder module.

    Args:
        odim (int): dimension of outputs
        dtype (str): gru or lstm
        dlayers (int): # prediction layers
        dunits (int): # prediction units
        blank (int): blank symbol id
        embed_dim (int): dimension of embeddings
        dropout (float): dropout rate
        dropout_embed (float): embedding dropout rate

    """

    def __init__(
        self,
        odim,
        dtype,
        dlayers,
        dunits,
        blank,
        embed_dim,
        dropout=0.0,
        dropout_embed=0.0,
    ):
        """Transducer initializer."""
        super().__init__()

        # Blank token embeds to a constant zero vector (padding_idx=blank).
        self.embed = torch.nn.Embedding(odim, embed_dim, padding_idx=blank)
        self.dropout_embed = torch.nn.Dropout(p=dropout_embed)

        dec_net = torch.nn.LSTM if dtype == "lstm" else torch.nn.GRU

        # The stack is built as dlayers single-layer RNNs (instead of one
        # multi-layer RNN) so that dropout can be applied after every layer
        # in rnn_forward, including the last one.
        self.decoder = torch.nn.ModuleList(
            [dec_net(embed_dim, dunits, 1, batch_first=True)]
        )
        self.dropout_dec = torch.nn.Dropout(p=dropout)

        for _ in range(1, dlayers):
            self.decoder += [dec_net(dunits, dunits, 1, batch_first=True)]

        self.dlayers = dlayers
        self.dunits = dunits
        self.dtype = dtype

        self.odim = odim

        self.ignore_id = -1
        self.blank = blank

        self.multi_gpus = torch.cuda.device_count() > 1

    def set_device(self, device):
        """Set GPU device to use.

        Args:
            device (torch.device): device id

        """
        self.device = device

    def set_data_type(self, data_type):
        """Set tensor data type used for decoder states.

        Args:
            data_type (torch.dtype): Tensor data type

        """
        self.data_type = data_type

    def init_state(self, batch_size):
        """Initialize decoder states.

        Args:
            batch_size (int): Batch size

        Returns:
            (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.dunits,
            device=self.device,
            dtype=self.data_type,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.dunits,
                device=self.device,
                dtype=self.data_type,
            )

            return (h_n, c_n)

        # GRU has no cell state; keep tuple shape for a uniform interface.
        return (h_n, None)

    def rnn_forward(self, y, state):
        """RNN forward.

        Args:
            y (torch.Tensor): batch of input features (B, emb_dim)
            state (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))

        Returns:
            y (torch.Tensor): batch of output features (B, dec_dim)
            (tuple): batch of decoder states
                (L, B, dec_dim), (L, B, dec_dim))

        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(y.size(0))

        # Run each single-layer RNN in turn, slicing the per-layer hidden
        # state out of the stacked (L, B, dec_dim) tensors.
        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                y, (
                    h_next[layer : layer + 1],
                    c_next[layer : layer + 1],
                ) = self.decoder[layer](
                    y, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])
                )
            else:
                y, h_next[layer : layer + 1] = self.decoder[layer](
                    y, hx=h_prev[layer : layer + 1]
                )

            y = self.dropout_dec(y)

        return y, (h_next, c_next)

    def forward(self, hs_pad, ys_in_pad):
        """Forward function for transducer.

        Args:
            hs_pad (torch.Tensor):
                batch of padded hidden state sequences (B, Tmax, D)
            ys_in_pad (torch.Tensor):
                batch of padded character id sequence tensor (B, Lmax+1)

        Returns:
            z (torch.Tensor): output (B, T, U, odim)

        """
        # Decoder states are created lazily on the encoder's device/dtype.
        self.set_device(hs_pad.device)
        self.set_data_type(hs_pad.dtype)

        state = self.init_state(hs_pad.size(0))
        eys = self.dropout_embed(self.embed(ys_in_pad))

        h_dec, _ = self.rnn_forward(eys, state)

        return h_dec

    def score(self, hyp, cache):
        """Forward one step.

        Args:
            hyp (dataclass): hypothesis
            cache (dict): states cache

        Returns:
            y (torch.Tensor): decoder outputs (1, dec_dim)
            state (tuple): decoder states
                ((L, 1, dec_dim), (L, 1, dec_dim)),
            (torch.Tensor): token id for LM (1,)

        """
        vy = torch.full((1, 1), hyp.yseq[-1], dtype=torch.long, device=self.device)

        # Label sequences are memoized by their string form so that
        # hypotheses sharing a prefix reuse the decoder computation.
        str_yseq = "".join(list(map(str, hyp.yseq)))

        if str_yseq in cache:
            y, state = cache[str_yseq]
        else:
            ey = self.embed(vy)

            y, state = self.rnn_forward(ey, hyp.dec_state)
            cache[str_yseq] = (y, state)

        return y[0][0], state, vy[0]

    def batch_score(self, hyps, batch_states, cache, use_lm):
        """Forward batch one step.

        Args:
            hyps (list): batch of hypotheses
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))
            cache (dict): states cache
            use_lm (bool): whether a LM is used for decoding

        Returns:
            batch_y (torch.Tensor): decoder output (B, dec_dim)
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))
            lm_tokens (torch.Tensor): batch of token ids for LM (B)

        """
        final_batch = len(hyps)

        process = []
        done = [None] * final_batch

        # Split hypotheses into cache hits (done) and misses (process);
        # only the misses go through the RNN below.
        for i, hyp in enumerate(hyps):
            str_yseq = "".join(list(map(str, hyp.yseq)))

            if str_yseq in cache:
                done[i] = cache[str_yseq]
            else:
                process.append((str_yseq, hyp.yseq[-1], hyp.dec_state))

        if process:
            # NOTE: torch.tensor is used instead of the legacy
            # torch.LongTensor constructor, which only accepts CPU devices
            # and raises on CUDA when decoding on GPU.
            tokens = torch.tensor(
                [[p[1]] for p in process], dtype=torch.long, device=self.device
            )
            dec_state = self.create_batch_states(
                self.init_state(tokens.size(0)), [p[2] for p in process]
            )

            ey = self.embed(tokens)
            y, dec_state = self.rnn_forward(ey, dec_state)

        # Scatter freshly computed outputs back into their hypothesis slots
        # and populate the cache for future steps.
        j = 0
        for i in range(final_batch):
            if done[i] is None:
                new_state = self.select_state(dec_state, j)

                done[i] = (y[j], new_state)
                cache[process[j][0]] = (y[j], new_state)

                j += 1

        batch_y = torch.cat([d[0] for d in done], dim=0)
        batch_states = self.create_batch_states(batch_states, [d[1] for d in done])

        if use_lm:
            # Same device-safety fix as for `tokens` above.
            lm_tokens = torch.tensor(
                [h.yseq[-1] for h in hyps], dtype=torch.long, device=self.device
            )

            return batch_y, batch_states, lm_tokens

        return batch_y, batch_states, None

    def select_state(self, batch_states, idx):
        """Get decoder state from batch of states, for given id.

        Args:
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))
            idx (int): index to extract state from batch of states

        Returns:
            (tuple): decoder states for given id
                ((L, 1, dec_dim), (L, 1, dec_dim))

        """
        return (
            batch_states[0][:, idx : idx + 1, :],
            batch_states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(self, batch_states, l_states, l_tokens=None):
        """Create batch of decoder states.

        Args:
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))
            l_states (list): list of decoder states
                [L x ((1, dec_dim), (1, dec_dim))]

        Returns:
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))

        """
        return (
            torch.cat([s[0] for s in l_states], dim=1),
            torch.cat([s[1] for s in l_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )
|
|