# NOTE: removed non-Python scrape residue ("tobiasc's picture / Initial commit / ad16788").
"""RNN decoder for transducer-based models."""
import torch
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface
class DecoderRNNT(TransducerDecoderInterface, torch.nn.Module):
    """RNN-T Decoder module.

    Args:
        odim (int): dimension of outputs
        dtype (str): gru or lstm
        dlayers (int): # prediction layers
        dunits (int): # prediction units
        blank (int): blank symbol id
        embed_dim (int): dimension of embeddings
        dropout (float): dropout rate
        dropout_embed (float): embedding dropout rate

    """

    def __init__(
        self,
        odim,
        dtype,
        dlayers,
        dunits,
        blank,
        embed_dim,
        dropout=0.0,
        dropout_embed=0.0,
    ):
        """Transducer initializer."""
        super().__init__()

        # The blank id doubles as the embedding padding index.
        self.embed = torch.nn.Embedding(odim, embed_dim, padding_idx=blank)
        self.dropout_embed = torch.nn.Dropout(p=dropout_embed)

        dec_net = torch.nn.LSTM if dtype == "lstm" else torch.nn.GRU

        # Stack single-layer RNN modules manually (instead of one multi-layer
        # RNN) so rnn_forward can apply dropout after EVERY layer, including
        # the last one, which torch's built-in inter-layer dropout skips.
        self.decoder = torch.nn.ModuleList(
            [dec_net(embed_dim, dunits, 1, batch_first=True)]
        )
        self.dropout_dec = torch.nn.Dropout(p=dropout)

        for _ in range(1, dlayers):
            self.decoder += [dec_net(dunits, dunits, 1, batch_first=True)]

        self.dlayers = dlayers
        self.dunits = dunits
        self.dtype = dtype
        self.odim = odim

        self.ignore_id = -1
        self.blank = blank

        self.multi_gpus = torch.cuda.device_count() > 1

    def set_device(self, device):
        """Set GPU device to use.

        Args:
            device (torch.device): device id

        """
        self.device = device

    def set_data_type(self, data_type):
        """Set tensor data type to use.

        Args:
            data_type (torch.dtype): Tensor data type

        """
        self.data_type = data_type

    def init_state(self, batch_size):
        """Initialize decoder states.

        Args:
            batch_size (int): Batch size

        Returns:
            (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim)) for LSTM,
                ((L, B, dec_dim), None) for GRU

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.dunits,
            device=self.device,
            dtype=self.data_type,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.dunits,
                device=self.device,
                dtype=self.data_type,
            )

            return (h_n, c_n)

        # GRU has no cell state; keep a None placeholder so callers can
        # always unpack a 2-tuple.
        return (h_n, None)

    def rnn_forward(self, y, state):
        """RNN forward.

        Args:
            y (torch.Tensor): batch of input features (B, U, emb_dim)
            state (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))

        Returns:
            y (torch.Tensor): batch of output features (B, U, dec_dim)
            (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))

        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(y.size(0))

        # Run each single-layer RNN in turn, collecting per-layer hidden
        # (and cell) states into the stacked (L, B, dec_dim) tensors.
        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                y, (
                    h_next[layer : layer + 1],
                    c_next[layer : layer + 1],
                ) = self.decoder[layer](
                    y, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])
                )
            else:
                y, h_next[layer : layer + 1] = self.decoder[layer](
                    y, hx=h_prev[layer : layer + 1]
                )
            # Dropout after every layer (see __init__ for why layers are
            # stacked manually).
            y = self.dropout_dec(y)

        return y, (h_next, c_next)

    def forward(self, hs_pad, ys_in_pad):
        """Forward function for transducer.

        Args:
            hs_pad (torch.Tensor):
                batch of padded hidden state sequences (B, Tmax, D);
                only its device/dtype/batch size are used here
            ys_in_pad (torch.Tensor):
                batch of padded character id sequence tensor (B, Lmax+1)

        Returns:
            h_dec (torch.Tensor): decoder output sequences (B, Lmax+1, dec_dim)

        """
        self.set_device(hs_pad.device)
        self.set_data_type(hs_pad.dtype)

        state = self.init_state(hs_pad.size(0))
        eys = self.dropout_embed(self.embed(ys_in_pad))

        h_dec, _ = self.rnn_forward(eys, state)

        return h_dec

    def score(self, hyp, cache):
        """Forward one step.

        Args:
            hyp (dataclass): hypothesis
            cache (dict): states cache

        Returns:
            y (torch.Tensor): decoder outputs (1, dec_dim)
            state (tuple): decoder states
                ((L, 1, dec_dim), (L, 1, dec_dim)),
            (torch.Tensor): token id for LM (1,)

        """
        vy = torch.full((1, 1), hyp.yseq[-1], dtype=torch.long, device=self.device)

        # Join with a separator: a plain "".join collides for distinct
        # sequences (e.g. [1, 21] vs [12, 1]), which would reuse the wrong
        # cached decoder state.
        str_yseq = "_".join(map(str, hyp.yseq))

        if str_yseq in cache:
            y, state = cache[str_yseq]
        else:
            ey = self.embed(vy)

            y, state = self.rnn_forward(ey, hyp.dec_state)
            cache[str_yseq] = (y, state)

        return y[0][0], state, vy[0]

    def batch_score(self, hyps, batch_states, cache, use_lm):
        """Forward batch one step.

        Args:
            hyps (list): batch of hypotheses
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))
            cache (dict): states cache
            use_lm (bool): whether a LM is used for decoding

        Returns:
            batch_y (torch.Tensor): decoder output (B, dec_dim)
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))
            lm_tokens (torch.Tensor): batch of token ids for LM (B)

        """
        final_batch = len(hyps)

        process = []
        done = [None] * final_batch

        # Split hypotheses into cache hits and those that still need a
        # decoder step (collision-free keys: see score()).
        for i, hyp in enumerate(hyps):
            str_yseq = "_".join(map(str, hyp.yseq))

            if str_yseq in cache:
                done[i] = cache[str_yseq]
            else:
                process.append((str_yseq, hyp.yseq[-1], hyp.dec_state))

        if process:
            # torch.tensor instead of the legacy torch.LongTensor
            # constructor: the latter does not accept non-CPU devices and
            # crashes during GPU decoding.
            tokens = torch.tensor(
                [[p[1]] for p in process], dtype=torch.long, device=self.device
            )

            dec_state = self.create_batch_states(
                self.init_state(tokens.size(0)), [p[2] for p in process]
            )

            ey = self.embed(tokens)
            y, dec_state = self.rnn_forward(ey, dec_state)

        # Merge freshly computed outputs back into hypothesis order and
        # populate the cache.
        j = 0
        for i in range(final_batch):
            if done[i] is None:
                new_state = self.select_state(dec_state, j)

                done[i] = (y[j], new_state)
                cache[process[j][0]] = (y[j], new_state)

                j += 1

        batch_y = torch.cat([d[0] for d in done], dim=0)
        batch_states = self.create_batch_states(batch_states, [d[1] for d in done])

        if use_lm:
            lm_tokens = torch.tensor(
                [h.yseq[-1] for h in hyps], dtype=torch.long, device=self.device
            )

            return batch_y, batch_states, lm_tokens

        return batch_y, batch_states, None

    def select_state(self, batch_states, idx):
        """Get decoder state from batch of states, for given id.

        Args:
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))
            idx (int): index to extract state from batch of states

        Returns:
            (tuple): decoder states for given id
                ((L, 1, dec_dim), (L, 1, dec_dim))

        """
        return (
            batch_states[0][:, idx : idx + 1, :],
            # GRU states carry None in slot 1; only index for LSTM.
            batch_states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(self, batch_states, l_states, l_tokens=None):
        """Create batch of decoder states.

        Args:
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim)); unused here, kept for
                interface compatibility
            l_states (list): list of decoder states
                [B x ((L, 1, dec_dim), (L, 1, dec_dim))]
            l_tokens: unused, kept for interface compatibility

        Returns:
            batch_states (tuple): batch of decoder states
                ((L, B, dec_dim), (L, B, dec_dim))

        """
        return (
            torch.cat([s[0] for s in l_states], dim=1),
            torch.cat([s[1] for s in l_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )