import logging
import torch
import torch.nn.functional as F
class CTC(torch.nn.Module):
"""CTC module.
Args:
odim: dimension of outputs
encoder_output_size: number of encoder projection units
dropout_rate: dropout rate (0.0 ~ 1.0)
ctc_type: builtin or warpctc
        reduce: reduce the CTC loss into a scalar
        ignore_nan_grad: ignore samples whose CTC gradient is NaN/Inf ("builtin" only)
"""
def __init__(
self,
odim: int,
encoder_output_size: int,
dropout_rate: float = 0.0,
ctc_type: str = "builtin",
reduce: bool = True,
ignore_nan_grad: bool = True,
):
super().__init__()
eprojs = encoder_output_size
self.dropout_rate = dropout_rate
self.ctc_lo = torch.nn.Linear(eprojs, odim)
self.ctc_type = ctc_type
self.ignore_nan_grad = ignore_nan_grad
if self.ctc_type == "builtin":
self.ctc_loss = torch.nn.CTCLoss(reduction="none")
elif self.ctc_type == "warpctc":
import warpctc_pytorch as warp_ctc
if ignore_nan_grad:
logging.warning("ignore_nan_grad option is not supported for warp_ctc")
self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
else:
raise ValueError(
f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'
)
self.reduce = reduce
def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
if self.ctc_type == "builtin":
th_pred = th_pred.log_softmax(2)
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
if loss.requires_grad and self.ignore_nan_grad:
# ctc_grad: (L, B, O)
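                # Calling loss.grad_fn directly evaluates the gradient of the
                # unreduced loss w.r.t. th_pred without running a full backward
                # pass; summing over the time and vocabulary dims (below) leaves
                # one value per batch element, which is then checked for NaN/Inf.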
ctc_grad = loss.grad_fn(torch.ones_like(loss))
ctc_grad = ctc_grad.sum([0, 2])
indices = torch.isfinite(ctc_grad)
size = indices.long().sum()
if size == 0:
# Return as is
logging.warning(
"All samples in this mini-batch got nan grad."
" Returning nan value instead of CTC loss"
)
elif size != th_pred.size(1):
logging.warning(
f"{th_pred.size(1) - size}/{th_pred.size(1)}"
" samples got nan grad."
" These were ignored for CTC loss."
)
# Create mask for target
target_mask = torch.full(
[th_target.size(0)],
1,
dtype=torch.bool,
device=th_target.device,
)
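                    # th_target is the 1-D concatenation of all label sequences,
                    # so walk it span by span (using th_olen) and zero out the
                    # spans belonging to samples with non-finite gradients.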
s = 0
for ind, le in enumerate(th_olen):
if not indices[ind]:
target_mask[s : s + le] = 0
s += le
                    # Recompute the loss using only samples with finite gradients
loss = self.ctc_loss(
th_pred[:, indices, :],
th_target[target_mask],
th_ilen[indices],
th_olen[indices],
)
else:
size = th_pred.size(1)
if self.reduce:
# Batch-size average
loss = loss.sum() / size
else:
loss = loss / size
return loss
elif self.ctc_type == "warpctc":
# warpctc only supports float32
th_pred = th_pred.to(dtype=torch.float32)
th_target = th_target.cpu().int()
th_ilen = th_ilen.cpu().int()
th_olen = th_olen.cpu().int()
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
if self.reduce:
                # NOTE: sum() is needed to keep consistency, since warpctc
                # returns a tensor of shape (1,) while the builtin loss
                # returns a scalar tensor (no shape).
loss = loss.sum()
return loss
elif self.ctc_type == "gtnctc":
log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
return self.ctc_loss(log_probs, th_target, th_ilen, 0, "none")
else:
raise NotImplementedError
def forward(self, hs_pad, hlens, ys_pad, ys_lens):
"""Calculate CTC loss.
Args:
hs_pad: batch of padded hidden state sequences (B, Tmax, D)
hlens: batch of lengths of hidden state sequences (B)
ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)
        Returns:
            torch.Tensor: CTC loss value
        """
# hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
if self.ctc_type == "gtnctc":
# gtn expects list form for ys
ys_true = [y[y != -1] for y in ys_pad] # parse padded ys
else:
# ys_hat: (B, L, D) -> (L, B, D)
ys_hat = ys_hat.transpose(0, 1)
# (B, L) -> (BxL,)
ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
device=hs_pad.device, dtype=hs_pad.dtype
)
return loss
def softmax(self, hs_pad):
"""softmax of frame activations
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
"""
return F.softmax(self.ctc_lo(hs_pad), dim=2)
def log_softmax(self, hs_pad):
"""log_softmax of frame activations
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
"""
return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
def argmax(self, hs_pad):
"""argmax of frame activations
Args:
torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: argmax applied 2d tensor (B, Tmax)
"""
return torch.argmax(self.ctc_lo(hs_pad), dim=2)
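

# Minimal usage sketch (illustrative, not part of the original module). It assumes a
# toy configuration with random encoder outputs and dummy targets padded with -1,
# following the shapes documented in forward(); all sizes below are arbitrary.
if __name__ == "__main__":
    batch, tmax, eprojs, vocab = 2, 50, 256, 10
    ctc = CTC(odim=vocab, encoder_output_size=eprojs)
    hs_pad = torch.randn(batch, tmax, eprojs)      # padded encoder outputs (B, Tmax, D)
    hlens = torch.tensor([50, 42])                 # valid encoder lengths (B,)
    ys_pad = torch.full((batch, 8), -1, dtype=torch.long)
    ys_pad[0, :6] = torch.randint(1, vocab, (6,))  # label ids; 0 is the builtin CTC blank
    ys_pad[1, :4] = torch.randint(1, vocab, (4,))
    ys_lens = torch.tensor([6, 4])                 # valid label lengths (B,)
    loss = ctc(hs_pad, hlens, ys_pad, ys_lens)     # scalar, batch-size averaged
    best_paths = ctc.argmax(hs_pad)                # greedy frame-level ids (B, Tmax)
    print(loss.item(), best_paths.shape)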