File size: 2,247 Bytes
9dce458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

import torch
import custom_ctc_cpp

Tensor = torch.Tensor

class CustomCTCLossFunction(torch.autograd.Function):
    """Custom CTC loss with an auxiliary real-valued prediction stream.

    The numeric work (alpha recursion and gradients) is delegated to the
    compiled ``custom_ctc_cpp`` extension; this class only marshals
    arguments, stashes what backward needs on ``ctx``, and applies the
    requested reduction.
    """

    @staticmethod
    def forward(
        ctx,
        log_probs: Tensor,
        targets: Tensor,
        realval: Tensor,
        targets_realval: Tensor,
        input_lengths: Tensor,
        target_lengths: Tensor,
        sigma: float = 1.0,
        blank: int = 0,
        blank1: int = 0,
        reduction: str = "mean",
        zero_infinity: bool = False,
        ):
        """Compute the CTC negative log-likelihood.

        Args:
            ctx: autograd context used to stash tensors for backward.
            log_probs: model log-probabilities (layout is defined by
                ``custom_ctc_cpp`` — presumably (T, N, C) as in
                ``torch.nn.CTCLoss``; TODO confirm).
            targets: class-index target sequences.
            realval: auxiliary real-valued predictions consumed by the kernel.
            targets_realval: targets for the real-valued stream.
            input_lengths: per-sample input lengths; a plain Python list
                (or tuple) is converted to a long tensor on
                ``log_probs.device``.
            target_lengths: per-sample target lengths; same coercion.
            sigma: kernel width forwarded to the extension.
            blank: primary blank label index.
            blank1: secondary blank label index.
            reduction: ``'none'`` for per-sample losses, ``'mean'`` for a
                target-length-normalized batch mean.
            zero_infinity: forwarded to the extension.

        Returns:
            A scalar (``reduction='mean'``) or the per-sample
            negative log-likelihood tensor (``reduction='none'``).

        Raises:
            ValueError: if ``reduction`` is not ``'none'`` or ``'mean'``.
        """
        # Real exception instead of `assert`: asserts vanish under `python -O`,
        # which would let an invalid reduction slip through silently.
        if reduction not in ("none", "mean"):
            raise ValueError(f"reduction must be 'none' or 'mean', got {reduction!r}")
        # Accept plain Python sequences of lengths (lists or tuples) and place
        # them on the same device as the log-probabilities. torch.as_tensor
        # builds the long tensor directly instead of the legacy
        # Tensor(list).long() float round-trip.
        if isinstance(input_lengths, (list, tuple)):
            input_lengths = torch.as_tensor(input_lengths, dtype=torch.long, device=log_probs.device)
        if isinstance(target_lengths, (list, tuple)):
            target_lengths = torch.as_tensor(target_lengths, dtype=torch.long, device=log_probs.device)
        neg_log_likelihood, log_alpha = custom_ctc_cpp.forward(
            log_probs, targets, realval, targets_realval,
            input_lengths, target_lengths, sigma, blank, blank1, zero_infinity,
        )
        # Save everything the C++ backward needs; scalars go on ctx directly.
        ctx.save_for_backward(
            neg_log_likelihood, log_alpha, log_probs, targets,
            realval, targets_realval, input_lengths, target_lengths,
        )
        ctx.blank = blank
        ctx.blank1 = blank1
        ctx.zero_infinity = zero_infinity
        ctx.sigma = sigma
        ctx.reduction = reduction
        if reduction == 'mean':
            # Normalize each sample's loss by its target length (clamped so
            # empty targets do not divide by zero), then average the batch.
            return (neg_log_likelihood / target_lengths.clamp_min(1)).mean()
        return neg_log_likelihood

    @staticmethod
    def backward(ctx, grad_out):
        """Propagate ``grad_out`` to ``log_probs`` and ``realval``.

        Returns one gradient slot per forward argument (11 total);
        only ``log_probs`` (slot 0) and ``realval`` (slot 2) receive
        gradients — every other input is non-differentiable.
        """
        (neg_log_likelihood, log_alpha, log_probs, targets,
         realval, targets_realval, input_lengths, target_lengths) = ctx.saved_tensors
        if ctx.reduction == 'mean':
            # Undo the forward 'mean' reduction: broadcast the scalar upstream
            # gradient to one value per sample, then re-apply the per-sample
            # target-length normalization used in forward.
            if grad_out.numel() == 0:
                grad_out = torch.ones_like(neg_log_likelihood)
            else:
                grad_out = grad_out.view(1).tile(neg_log_likelihood.size(0))
            grad_out /= target_lengths.clamp_min(1)
            # NOTE(review): forward's .mean() averages over
            # neg_log_likelihood.size(0), but this divides by
            # log_probs.size(0); the two only agree if log_probs is
            # batch-first. Confirm against custom_ctc_cpp's tensor layout —
            # for the (T, N, C) layout of torch.nn.CTCLoss this would be
            # dividing by T instead of the batch size.
            grad_out /= log_probs.size(0)
        outputs_cls, outputs_realval = custom_ctc_cpp.backward(
            grad_out, log_probs, targets, realval, targets_realval,
            input_lengths, target_lengths, neg_log_likelihood, log_alpha,
            ctx.sigma, ctx.blank, ctx.blank1, ctx.zero_infinity,
        )
        return (outputs_cls, None, outputs_realval,
                None, None, None, None, None, None, None, None)