tgritsaev's picture
Upload 198 files
affcd23 verified
raw
history blame
580 Bytes
import torch
from torch import Tensor
from torch.nn import CTCLoss
class CTCLossWrapper(CTCLoss):
    """``torch.nn.CTCLoss`` adapted to batch-first log-probs and batch-dict kwargs.

    ``CTCLoss`` expects ``log_probs`` laid out as ``(T, N, C)``. This wrapper
    swaps the first two dimensions before delegating, so callers pass
    batch-first ``(N, T, C)`` tensors. ``forward`` also swallows unrelated keys,
    so it can be called as ``loss(**batch)``.

    Args:
        blank: Index of the CTC blank label.
        reduction: ``"none"``, ``"mean"`` or ``"sum"``, as in ``CTCLoss``.
        zero_infinity: Replace infinite losses (and their gradients) with zero.
            Infinite losses occur when an input is too short to align with its
            target. Defaults to ``True``, matching this wrapper's historical
            behaviour.
    """

    def __init__(self, *, blank=0, reduction="mean", zero_infinity=True):
        super().__init__(
            blank=blank, reduction=reduction, zero_infinity=zero_infinity
        )

    def forward(self, log_probs, log_probs_length, text_encoded, text_encoded_length,
                **batch) -> Tensor:
        """Compute the CTC loss for a batch.

        Args:
            log_probs: Log-probabilities, assumed batch-first ``(N, T, C)``.
                The first two dims are swapped to the ``(T, N, C)`` layout
                ``CTCLoss`` expects.
            log_probs_length: Valid time steps per sample; passed to
                ``CTCLoss`` as ``input_lengths``.
            text_encoded: Encoded target label indices; passed to ``CTCLoss``
                as ``targets``.
            text_encoded_length: Target length per sample; passed to
                ``CTCLoss`` as ``target_lengths``.
            **batch: Any other batch entries. They are ignored, so the whole
                batch dict can be splatted in.

        Returns:
            The loss tensor, reduced according to ``reduction``.
        """
        # CTCLoss wants time-major input; callers supply batch-major.
        log_probs_t = torch.transpose(log_probs, 0, 1)
        return super().forward(
            log_probs=log_probs_t,
            targets=text_encoded,
            input_lengths=log_probs_length,
            target_lengths=text_encoded_length,
        )