File size: 580 Bytes
affcd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torch import Tensor
from torch.nn import CTCLoss


class CTCLossWrapper(CTCLoss):
    """``torch.nn.CTCLoss`` adapter that takes its inputs under batch-style names.

    ``forward`` receives the loss inputs as ``log_probs``, ``log_probs_length``,
    ``text_encoded`` and ``text_encoded_length`` and silently ignores any other
    keyword arguments, so a whole batch dictionary can be unpacked into the
    call. The first two axes of ``log_probs`` are swapped before delegating to
    the parent loss.
    """

    def __init__(
        self,
        blank: int = 0,
        reduction: str = "mean",
        zero_infinity: bool = True,
    ) -> None:
        """Configure the underlying CTC loss.

        With no arguments this matches the wrapper's historical behaviour:
        ``CTCLoss`` defaults plus ``zero_infinity=True``.

        Args:
            blank: Index of the blank label, forwarded to ``CTCLoss``.
            reduction: Reduction mode, forwarded to ``CTCLoss``.
            zero_infinity: Whether infinite losses (and their gradients) are
                zeroed, forwarded to ``CTCLoss``.
        """
        super().__init__(
            blank=blank,
            reduction=reduction,
            zero_infinity=zero_infinity,
        )

    def forward(
        self,
        log_probs: Tensor,
        log_probs_length: Tensor,
        text_encoded: Tensor,
        text_encoded_length: Tensor,
        **batch,
    ) -> Tensor:
        """Compute the CTC loss for one batch.

        Args:
            log_probs: Log-probabilities; axes 0 and 1 are swapped before the
                parent loss is called. NOTE(review): presumably laid out as
                (batch, time, classes), since ``CTCLoss`` documents a
                (time, batch, classes) input -- confirm against the model.
            log_probs_length: Passed to ``CTCLoss`` as ``input_lengths``.
            text_encoded: Passed to ``CTCLoss`` as ``targets``.
            text_encoded_length: Passed to ``CTCLoss`` as ``target_lengths``.
            **batch: Any remaining batch entries; accepted and ignored.

        Returns:
            The loss tensor produced by ``CTCLoss.forward``.
        """
        # Swap the first two axes to get the layout handed to the parent loss.
        swapped_log_probs = log_probs.transpose(0, 1)

        return super().forward(
            log_probs=swapped_log_probs,
            targets=text_encoded,
            input_lengths=log_probs_length,
            target_lengths=text_encoded_length,
        )