import torch
from torch import nn
from torch.nn import Module
from torch.nn import functional as F


class ForwardSumLoss(Module):
    r"""Computes the forward sum loss for sequence-to-sequence models with attention.

    Args:
        blank_logprob (float): The log probability of the blank symbol. Default: -1.

    Attributes:
        log_softmax (nn.LogSoftmax): The log softmax function.
        ctc_loss (nn.CTCLoss): The CTC loss function.
        blank_logprob (float): The log probability of the blank symbol.

    Methods:
        forward: Computes the forward sum loss for sequence-to-sequence models with attention.

    """

    def __init__(self, blank_logprob: float = -1):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=3)
        self.ctc_loss = nn.CTCLoss(zero_infinity=True)
        self.blank_logprob = blank_logprob

    def forward(
        self, attn_logprob: torch.Tensor, in_lens: torch.Tensor, out_lens: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computes the forward sum loss for a batch of attention matrices.

        Args:
            attn_logprob (torch.Tensor): Attention log probabilities of shape
                (batch_size, 1, max_out_len, max_in_len).
            in_lens (torch.Tensor): Input (key) lengths of shape (batch_size,).
            out_lens (torch.Tensor): Output (query) lengths of shape (batch_size,).

        Returns:
            torch.Tensor: Scalar forward sum loss, averaged over the batch.

        """
        key_lens = in_lens
        query_lens = out_lens

        # Insert a blank "column" at key index 0 so CTC can emit its blank
        # symbol; real key positions shift to indices 1..key_len.
        attn_logprob_padded = F.pad(
            input=attn_logprob, pad=(1, 0), value=self.blank_logprob,
        )

        total_loss = 0.0
        for bid in range(attn_logprob.shape[0]):
            # The CTC target is simply 1, 2, ..., key_len: every key position
            # must be visited in order, which enforces monotonic alignment.
            target_seq = torch.arange(1, int(key_lens[bid]) + 1).unsqueeze(0)

            # Reshape to CTC's expected (T, N, C) layout and trim padding to
            # (query_len, 1, key_len + 1) for this batch element.
            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
                : int(query_lens[bid]), :, : int(key_lens[bid]) + 1,
            ]

            # Renormalize over the (blank + key) dimension after padding.
            curr_logprob = self.log_softmax(curr_logprob[None])[0]
            loss = self.ctc_loss(
                curr_logprob,
                target_seq,
                input_lengths=query_lens[bid : bid + 1],
                target_lengths=key_lens[bid : bid + 1],
            )
            total_loss += loss

        # Average over the batch.
        total_loss /= attn_logprob.shape[0]
        return total_loss
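

# A minimal smoke test sketching the expected call pattern. The shapes and
# values below are illustrative assumptions, not part of the module's API:
# attention log-probabilities of shape (batch, 1, max_out_len, max_in_len)
# plus per-item input (key) and output (query) lengths, with each output
# length at least as large as its input length.
if __name__ == "__main__":
    batch, max_out_len, max_in_len = 2, 10, 6
    attn_logprob = torch.randn(batch, 1, max_out_len, max_in_len)
    in_lens = torch.tensor([6, 4])
    out_lens = torch.tensor([10, 7])

    criterion = ForwardSumLoss()
    loss = criterion(attn_logprob, in_lens, out_lens)
    print(loss)  # scalar tensor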