import torch
from torch import nn
from torch.nn import Module
from torch.nn import functional as F


class ForwardSumLoss(Module):
    r"""Computes the forward sum loss for sequence-to-sequence models with attention.

    Args:
        blank_logprob (float): The log probability assigned to the blank symbol. Default: -1.

    Attributes:
        log_softmax (nn.LogSoftmax): The log softmax applied over the key dimension.
        ctc_loss (nn.CTCLoss): The CTC loss function.
        blank_logprob (float): The log probability assigned to the blank symbol.

    Methods:
        forward: Computes the forward sum loss for sequence-to-sequence models with attention.
    """

    def __init__(self, blank_logprob: float = -1):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=3)
        self.ctc_loss = nn.CTCLoss(zero_infinity=True)
        self.blank_logprob = blank_logprob

    def forward(
        self,
        attn_logprob: torch.Tensor,
        in_lens: torch.Tensor,
        out_lens: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computes the forward sum loss for sequence-to-sequence models with attention.

        Args:
            attn_logprob (torch.Tensor): The attention log probabilities of shape
                (batch_size, 1, max_out_len, max_in_len).
            in_lens (torch.Tensor): The input (key) lengths of shape (batch_size,).
            out_lens (torch.Tensor): The output (query) lengths of shape (batch_size,).

        Returns:
            torch.Tensor: The forward sum loss, averaged over the batch.
        """
        key_lens = in_lens
        query_lens = out_lens

        # Prepend a column for the CTC blank symbol (index 0) along the key dimension.
        attn_logprob_padded = F.pad(
            input=attn_logprob,
            pad=(1, 0),
            value=self.blank_logprob,
        )

        total_loss = 0.0
        for bid in range(attn_logprob.shape[0]):
            # CTC target: each key index 1..key_len appears exactly once, in order.
            target_seq = torch.arange(1, int(key_lens[bid]) + 1).unsqueeze(0)

            # Slice and permute to (query_len, 1, key_len + 1), i.e. (T, N, C) as expected by CTCLoss.
            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
                : int(query_lens[bid]),
                :,
                : int(key_lens[bid]) + 1,
            ]

            # Normalize over the key dimension (dim=3 after adding a leading batch axis).
            curr_logprob = self.log_softmax(curr_logprob[None])[0]
            loss = self.ctc_loss(
                curr_logprob,
                target_seq,
                input_lengths=query_lens[bid : bid + 1],
                target_lengths=key_lens[bid : bid + 1],
            )
            total_loss += loss

        total_loss /= attn_logprob.shape[0]
        return total_loss
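

# Minimal usage sketch (not part of the original module; shapes and length values
# below are illustrative assumptions). It builds random attention scores for a
# batch of two sequences and checks that the loss evaluates to a finite scalar.
# forward() applies log-softmax internally, so unnormalized scores are fine here.
if __name__ == "__main__":
    batch_size, max_in_len, max_out_len = 2, 10, 40
    loss_fn = ForwardSumLoss()

    # (batch_size, 1, max_out_len, max_in_len): queries are e.g. mel frames,
    # keys are e.g. text tokens.
    attn_logprob = torch.randn(batch_size, 1, max_out_len, max_in_len)
    in_lens = torch.tensor([10, 7])    # key (text-token) lengths per sample
    out_lens = torch.tensor([40, 25])  # query (frame) lengths per sample

    loss = loss_fn(attn_logprob, in_lens, out_lens)
    print(loss)  # scalar tensor, averaged over the batch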