"""GTN CTC implementation."""

import gtn
import torch


class GTNCTCLossFunction(torch.autograd.Function):
    """GTN CTC module."""

    @staticmethod
    def create_ctc_graph(target, blank_idx):
        """Build gtn graph.

        :param list target: single target sequence
        :param int blank_idx: index of blank token
        :return: gtn graph of target sequence
        :rtype: gtn.Graph
        """
        g_criterion = gtn.Graph(False)
        L = len(target)
        S = 2 * L + 1
        for s in range(S):
            idx = (s - 1) // 2
            # Node s is a start node iff s == 0; the last blank and the last
            # label are both valid accept nodes.
            g_criterion.add_node(s == 0, s == S - 1 or s == S - 2)
            # Even states emit the blank token, odd states emit the next label.
            label = target[idx] if s % 2 else blank_idx
            g_criterion.add_arc(s, s, label)  # self-loop
            if s > 0:
                g_criterion.add_arc(s - 1, s, label)  # step from previous state
            if s % 2 and s > 1 and label != target[idx - 1]:
                # Skip the blank between two distinct consecutive labels.
                g_criterion.add_arc(s - 2, s, label)
        g_criterion.arc_sort(False)
        return g_criterion
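
    # Example (illustrative, not part of the original module): for
    # target = [1, 2] with blank_idx = 0, S = 5 and the five states emit
    # [0, 1, 0, 2, 0].  Every state has a self-loop, each state s > 0 is
    # reachable from s - 1, and state 3 (label 2) is also reachable from
    # state 1 (label 1), i.e. the blank between distinct labels may be skipped.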

    @staticmethod
    def forward(ctx, log_probs, targets, blank_idx=0, reduction="none"):
        """Forward computation.

        :param torch.Tensor log_probs: batched log softmax probabilities (B, Tmax, oDim)
        :param list targets: batched target sequences, list of lists
        :param int blank_idx: index of blank token
        :param str reduction: loss reduction mode, "none" or "mean"
        :return: ctc loss value
        :rtype: torch.Tensor
        """
|
        B, T, C = log_probs.shape
        losses = [None] * B
        scales = [None] * B
        emissions_graphs = [None] * B

        def process(b):
            # Build an emission graph over the frame posteriors; gtn reads the
            # weights directly from the tensor's CPU memory, so the slice must
            # be contiguous.
            g_emissions = gtn.linear_graph(T, C, log_probs.requires_grad)
            cpu_data = log_probs[b].cpu().contiguous()
            g_emissions.set_weights(cpu_data.data_ptr())

            # Build the criterion graph for this utterance's target sequence.
            g_criterion = GTNCTCLossFunction.create_ctc_graph(targets[b], blank_idx)

            # Intersect emissions with the criterion; the negated forward score
            # of the composed graph is the CTC loss for this utterance.
            g_loss = gtn.negate(
                gtn.forward_score(gtn.intersect(g_emissions, g_criterion))
            )

            scale = 1.0
            if reduction == "mean":
                L = len(targets[b])
                scale = 1.0 / L if L > 0 else scale
            elif reduction != "none":
                raise ValueError("invalid value for reduction '" + str(reduction) + "'")

            losses[b] = g_loss
            scales[b] = scale
            emissions_graphs[b] = g_emissions

        gtn.parallel_for(process, range(B))

        # Keep the gtn graphs for the backward pass; they are not tensors,
        # so they cannot go through save_for_backward.
        ctx.auxiliary_data = (losses, scales, emissions_graphs, log_probs.shape)
        loss = torch.tensor([losses[b].item() * scales[b] for b in range(B)])
        return torch.mean(loss.cuda() if log_probs.is_cuda else loss)
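
    # Note: as a torch.autograd.Function, this loss is meant to be called via
    # GTNCTCLossFunction.apply(log_probs, targets, blank_idx, reduction) rather
    # than by invoking forward() directly.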

    @staticmethod
    def backward(ctx, grad_output):
        """Backward computation.

        :param torch.Tensor grad_output: backward passed gradient value
        :return: cumulative gradient output
        :rtype: (torch.Tensor, None, None, None)
        """
|
        losses, scales, emissions_graphs, in_shape = ctx.auxiliary_data
        B, T, C = in_shape
        input_grad = torch.empty((B, T, C))

        def process(b):
            # Backpropagate through the gtn graph and copy the emission
            # gradients back into the torch tensor.
            gtn.backward(losses[b], False)
            emissions = emissions_graphs[b]
            grad = emissions.grad().weights_to_numpy()
            input_grad[b] = torch.from_numpy(grad).view(1, T, C) * scales[b]

        gtn.parallel_for(process, range(B))

        if grad_output.is_cuda:
            input_grad = input_grad.cuda()
        # forward() returns the batch mean, so spread grad_output over B.
        input_grad *= grad_output / B

        return (
            input_grad,
            None,
            None,
            None,
        )
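

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the gtn Python bindings are installed; the shapes and targets below
# are made-up toy values.
if __name__ == "__main__":
    B, T, C = 2, 4, 3  # batch size, input frames, vocabulary size (blank = 0)
    logits = torch.randn(B, T, C, requires_grad=True)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    targets = [[1, 2], [2, 2]]  # one list of label indices per utterance
    loss = GTNCTCLossFunction.apply(log_probs, targets, 0, "none")
    loss.backward()
    print(f"loss: {loss.item():.4f}, grad shape: {tuple(logits.grad.shape)}")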