import torch | |
from torch import Tensor | |
from torch.nn import CTCLoss | |
class CTCLossWrapper(CTCLoss):
    """``nn.CTCLoss`` adapter that accepts batch-first inputs by keyword.

    The wrapper is meant to be called with a whole batch dictionary
    (``loss(**batch)``): it picks out the four tensors it needs by name,
    ignores every other key, and reorders ``log_probs`` into the time-major
    layout ``nn.CTCLoss`` requires.
    """

    def __init__(self, blank: int = 0, reduction: str = "mean",
                 zero_infinity: bool = True) -> None:
        """Configure the underlying CTC loss.

        Args:
            blank: Index of the blank label. Defaults to 0, which is the
                ``nn.CTCLoss`` default this wrapper previously relied on.
            reduction: ``"none"``, ``"mean"`` or ``"sum"``; forwarded
                unchanged to ``nn.CTCLoss``.
            zero_infinity: When True (the default, matching the value that
                used to be hard-coded), infinite losses and their gradients
                are replaced by zero instead of propagating ``inf``/``nan``
                into training. Infinite losses arise, for example, when an
                input sequence is too short to align with its target.
        """
        super().__init__(blank=blank, reduction=reduction,
                         zero_infinity=zero_infinity)

    def forward(self, log_probs: Tensor, log_probs_length, text_encoded: Tensor,
                text_encoded_length, **batch) -> Tensor:
        """Compute the CTC loss for a batch-first set of predictions.

        Args:
            log_probs: Log-probabilities of shape ``(N, T, C)`` (batch
                first). As with ``nn.CTCLoss``, they should already be
                log-softmax normalized over the class dimension.
            log_probs_length: Valid length of each input sequence, shape
                ``(N,)``.
            text_encoded: Target label indices, either padded ``(N, S)`` or
                concatenated 1-D, as accepted by ``nn.CTCLoss``.
            text_encoded_length: Length of each target sequence, shape
                ``(N,)``.
            **batch: Any other batch entries; accepted so the whole batch
                dict can be splatted in, and otherwise unused.

        Returns:
            The loss reduced according to ``reduction``.
        """
        # nn.CTCLoss expects time-major (T, N, C); swap batch and time axes.
        log_probs_t = torch.transpose(log_probs, 0, 1)
        return super().forward(
            log_probs=log_probs_t,
            targets=text_encoded,
            input_lengths=log_probs_length,
            target_lengths=text_encoded_length,
        )