"""Custom decoder definition for transducer models.""" import torch from espnet.nets.pytorch_backend.transducer.blocks import build_blocks from espnet.nets.pytorch_backend.transducer.utils import check_batch_state from espnet.nets.pytorch_backend.transducer.utils import check_state from espnet.nets.pytorch_backend.transducer.utils import pad_sequence from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface class CustomDecoder(TransducerDecoderInterface, torch.nn.Module): """Custom decoder module for transducer models. Args: odim (int): dimension of outputs dec_arch (list): list of layer definitions input_layer (str): input layer type repeat_block (int): repeat provided blocks N times if N > 1 positional_encoding_type (str): positional encoding type positionwise_layer_type (str): linear positionwise_activation_type (str): positionwise activation type dropout_rate_embed (float): dropout rate for embedding layer (if specified) blank (int): blank symbol ID """ def __init__( self, odim, dec_arch, input_layer="embed", repeat_block=0, joint_activation_type="tanh", positional_encoding_type="abs_pos", positionwise_layer_type="linear", positionwise_activation_type="relu", dropout_rate_embed=0.0, blank=0, ): """Construct a CustomDecoder object.""" torch.nn.Module.__init__(self) self.embed, self.decoders, ddim, _ = build_blocks( "decoder", odim, input_layer, dec_arch, repeat_block=repeat_block, positional_encoding_type=positional_encoding_type, positionwise_layer_type=positionwise_layer_type, positionwise_activation_type=positionwise_activation_type, dropout_rate_embed=dropout_rate_embed, padding_idx=blank, ) self.after_norm = LayerNorm(ddim) self.dlayers = len(self.decoders) self.dunits = ddim self.odim = odim self.blank = blank def set_device(self, device): """Set GPU device to use. Args: device (torch.device): device id """ self.device = device def init_state(self, batch_size=None, device=None, dtype=None): """Initialize decoder states. Args: None Returns: state (list): batch of decoder decoder states [L x None] """ state = [None] * self.dlayers return state def forward(self, tgt, tgt_mask, memory): """Forward custom decoder. Args: tgt (torch.Tensor): input token ids, int64 (batch, maxlen_out) if input_layer == "embed" input tensor (batch, maxlen_out, #mels) in the other cases tgt_mask (torch.Tensor): input token mask, (batch, maxlen_out) dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (include 1.2) memory (torch.Tensor): encoded memory, float32 (batch, maxlen_in, feat) Return: tgt (torch.Tensor): decoder output (batch, maxlen_out, dim_dec) tgt_mask (torch.Tensor): score mask before softmax (batch, maxlen_out) """ tgt = self.embed(tgt) tgt, tgt_mask = self.decoders(tgt, tgt_mask) tgt = self.after_norm(tgt) return tgt, tgt_mask def score(self, hyp, cache): """Forward one step. 
    def score(self, hyp, cache):
        """Forward one step.

        Args:
            hyp (dataclass): hypothesis
            cache (dict): states cache

        Returns:
            y (torch.Tensor): decoder outputs (1, dec_dim)
            (list): decoder states
                [L x (1, max_len, dec_dim)]
            lm_tokens (torch.Tensor): token id for LM (1)

        """
        tgt = torch.tensor([hyp.yseq], device=self.device)
        lm_tokens = tgt[:, -1]

        str_yseq = "".join(list(map(str, hyp.yseq)))

        if str_yseq in cache:
            y, new_state = cache[str_yseq]
        else:
            tgt_mask = subsequent_mask(len(hyp.yseq)).unsqueeze_(0)

            state = check_state(hyp.dec_state, (tgt.size(1) - 1), self.blank)

            tgt = self.embed(tgt)

            new_state = []
            for s, decoder in zip(state, self.decoders):
                tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s)
                new_state.append(tgt)

            y = self.after_norm(tgt[:, -1])

            cache[str_yseq] = (y, new_state)

        return y[0], new_state, lm_tokens

    def batch_score(self, hyps, batch_states, cache, use_lm):
        """Forward batch one step.

        Args:
            hyps (list): batch of hypotheses
            batch_states (list): decoder states
                [L x (B, max_len, dec_dim)]
            cache (dict): states cache
            use_lm (bool): whether LM tokens should be returned

        Returns:
            batch_y (torch.Tensor): decoder output (B, dec_dim)
            batch_states (list): decoder states
                [L x (B, max_len, dec_dim)]
            lm_tokens (torch.Tensor): batch of token ids for LM (B),
                None if use_lm is False

        """
        final_batch = len(hyps)

        process = []
        done = [None for _ in range(final_batch)]

        for i, hyp in enumerate(hyps):
            str_yseq = "".join(list(map(str, hyp.yseq)))

            if str_yseq in cache:
                done[i] = cache[str_yseq]
            else:
                process.append((str_yseq, hyp.yseq, hyp.dec_state))

        if process:
            _tokens = pad_sequence([p[1] for p in process], self.blank)
            batch_tokens = torch.tensor(
                _tokens, dtype=torch.long, device=self.device
            )

            tgt_mask = (
                subsequent_mask(batch_tokens.size(-1))
                .unsqueeze_(0)
                .expand(len(process), -1, -1)
            )

            dec_state = self.create_batch_states(
                self.init_state(),
                [p[2] for p in process],
                _tokens,
            )

            tgt = self.embed(batch_tokens)

            next_state = []
            for s, decoder in zip(dec_state, self.decoders):
                tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s)
                next_state.append(tgt)

            tgt = self.after_norm(tgt[:, -1])

        j = 0
        for i in range(final_batch):
            if done[i] is None:
                new_state = self.select_state(next_state, j)

                done[i] = (tgt[j], new_state)
                cache[process[j][0]] = (tgt[j], new_state)

                j += 1

        self.create_batch_states(
            batch_states,
            [d[1] for d in done],
            [[0] + h.yseq for h in hyps],
        )
        batch_y = torch.stack([d[0] for d in done])

        if use_lm:
            lm_tokens = torch.tensor(
                [hyp.yseq[-1] for hyp in hyps], dtype=torch.long, device=self.device
            )

            return batch_y, batch_states, lm_tokens

        return batch_y, batch_states, None

    def select_state(self, batch_states, idx):
        """Get decoder state from batch of states, for given id.

        Args:
            batch_states (list): batch of decoder states
                [L x (B, max_len, dec_dim)]
            idx (int): index to extract state from batch of states

        Returns:
            state_idx (list): decoder states for given id
                [L x (1, max_len, dec_dim)]

        """
        if batch_states[0] is None:
            return batch_states

        state_idx = [batch_states[layer][idx] for layer in range(self.dlayers)]

        return state_idx

    def create_batch_states(self, batch_states, l_states, check_list):
        """Create batch of decoder states.

        Args:
            batch_states (list): batch of decoder states
                [L x (B, max_len, dec_dim)]
            l_states (list): list of decoder states
                [B x [L x (1, max_len, dec_dim)]]
            check_list (list): list of sequences for max_len

        Returns:
            batch_states (list): batch of decoder states
                [L x (B, max_len, dec_dim)]

        """
        if l_states[0][0] is None:
            return batch_states

        max_len = max(len(elem) for elem in check_list) - 1

        for layer in range(self.dlayers):
            batch_states[layer] = check_batch_state(
                [s[layer] for s in l_states], max_len, self.blank
            )

        return batch_states
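

# ---------------------------------------------------------------------------
# Minimal usage sketch (kept as comments so the module stays import-only).
# The ``dec_arch`` entry below is an assumption about the block-definition
# dictionaries accepted by ``build_blocks`` and may need adjusting; the
# hypothesis object is a hypothetical stand-in exposing only the ``yseq`` and
# ``dec_state`` attributes that ``score`` actually reads.
#
#   from types import SimpleNamespace
#
#   decoder = CustomDecoder(
#       odim=500,
#       dec_arch=[
#           {"type": "transformer", "d_hidden": 256, "d_ff": 1024, "heads": 4}
#       ],
#       repeat_block=2,
#   )
#   decoder.set_device(torch.device("cpu"))
#
#   hyp = SimpleNamespace(yseq=[decoder.blank], dec_state=decoder.init_state())
#   cache = {}
#   y, state, lm_tokens = decoder.score(hyp, cache)
#   # y: decoder output for the last position of the prefix,
#   # state: [L x (1, 1, dec_dim)] per-layer states for the next step
# ---------------------------------------------------------------------------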