File size: 580 Bytes
affcd23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
from torch import Tensor
from torch.nn import CTCLoss
class CTCLossWrapper(CTCLoss):
    """``nn.CTCLoss`` adapter that takes batch-first log-probabilities.

    ``forward`` swaps the first two dimensions of ``log_probs`` before calling
    ``nn.CTCLoss``, which requires time-major ``(T, N, C)`` input; the caller
    therefore supplies ``(N, T, C)``.

    Unlike stock ``nn.CTCLoss`` (default ``zero_infinity=False``), this wrapper
    defaults to ``zero_infinity=True`` so infeasible alignments (e.g. a target
    that cannot fit in its input length) yield a zero loss/gradient instead of
    ``inf``.

    Args:
        blank: Index of the CTC blank token. Defaults to ``0``.
        reduction: ``"none"``, ``"mean"`` or ``"sum"``; forwarded unchanged to
            ``nn.CTCLoss``. Defaults to ``"mean"``.
        zero_infinity: Whether infinite losses (and their gradients) are
            replaced by zero. Defaults to ``True``.
    """

    def __init__(
        self,
        blank: int = 0,
        reduction: str = "mean",
        zero_infinity: bool = True,
    ) -> None:
        super().__init__(
            blank=blank,
            reduction=reduction,
            zero_infinity=zero_infinity,
        )

    def forward(
        self,
        log_probs: Tensor,
        log_probs_length: Tensor,
        text_encoded: Tensor,
        text_encoded_length: Tensor,
        **batch,
    ) -> Tensor:
        """Compute the CTC loss for a batch-first batch.

        Args:
            log_probs: Log-probabilities laid out ``(N, T, C)`` (batch first);
                as with ``nn.CTCLoss`` these should be log-softmax outputs.
            log_probs_length: Valid time length of each sample in
                ``log_probs``; passed as ``input_lengths``.
            text_encoded: Target label indices; passed as ``targets``.
            text_encoded_length: Length of each target sequence; passed as
                ``target_lengths``.
            **batch: Any other keyword arguments are accepted and ignored
                (presumably so a whole batch dict can be splatted in —
                confirm with callers).

        Returns:
            The loss as produced by ``nn.CTCLoss`` for the configured
            ``reduction``: a scalar for ``"mean"``/``"sum"``, a per-sample
            ``(N,)`` tensor for ``"none"``.
        """
        # nn.CTCLoss wants time-major (T, N, C); transpose returns a view.
        log_probs_t = torch.transpose(log_probs, 0, 1)
        return super().forward(
            log_probs=log_probs_t,
            targets=text_encoded,
            input_lengths=log_probs_length,
            target_lengths=text_encoded_length,
        )
|