# testapi/training/ocr/test_ctc.py
import torch
import custom_ctc
import custom_ctc_gpu
import numpy as np
import torch.nn.functional as F
from torch.autograd import gradcheck
custom_ctc_f = custom_ctc.CustomCTCLossFunction.apply
custom_ctc_f_gpu = custom_ctc_gpu.CustomCTCLossFunction.apply
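
# Gradient tests for the custom CTC loss: test_ctc_loss_custom numerically gradchecks
# the CPU implementation (custom_ctc), while test_ctc_loss_custom_gpu compares the GPU
# implementation (custom_ctc_gpu) against the CPU one on identical random inputs.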
def test_ctc_loss_custom(device):
    batch_size = 64
    num_labels = 101
    target_length = 15
    gradcheck_input_size = 10

    ZERO_NONE = 0
    ZERO_SOME = 1
    ZERO_ALL = 2
    # input_length, vary_lengths, zero_lengths
    tests = [(150, False, ZERO_NONE),
             (150, True, ZERO_NONE),
             (50, True, ZERO_SOME),
             (50, True, ZERO_ALL)]
    tests += [(50, False, ZERO_NONE),
              (50, True, ZERO_NONE),
              (150, True, ZERO_SOME),
              (150, True, ZERO_ALL)]

    for input_length, vary_lengths, zero_mode in tests:
        targets = torch.randint(1, num_labels, (batch_size, target_length),
                                device=device, dtype=torch.long)
        # Keep the gradcheck input small (gradcheck_input_size doubles) and expand it to
        # the full (batch, time, labels) volume with fixed tile factors, so gradcheck stays fast.
        x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True)
        tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
                                   device=device)
        input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                          if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
        if zero_mode == ZERO_ALL:
            target_lengths = [0 for _ in range(batch_size)]
        else:
            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
                               if vary_lengths else target_length) for _ in range(batch_size)]
            if zero_mode == ZERO_SOME:
                # Zero out the target length for a random subset of batch elements.
                idxes = torch.randint(0, batch_size, (10,))
                for i in idxes:
                    target_lengths[i] = 0

        # Real-valued targets are checked alongside the discrete labels.
        num_realval = np.random.randint(1, 16)
        rv_x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True)
        tile_factors_rv = torch.randn(batch_size * input_length * num_realval // gradcheck_input_size + 1,
                                      device=device)
        targets_realvals = torch.randn(batch_size, input_length, num_realval, dtype=torch.double)
        blank1 = np.random.randint(1, num_labels - 1)

        def ctc_after_softmax(x, rv):
            x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
                      .view(batch_size, input_length, num_labels))
            rv_full = ((rv[:, None] * tile_factors_rv[None, :]).view(-1)[:input_length * batch_size * num_realval]
                       .view(batch_size, input_length, num_realval))
            log_probs = torch.log_softmax(x_full, 2)
            return custom_ctc_f(log_probs, targets, rv_full, targets_realvals,
                                input_lengths, target_lengths, 2.2, 0, blank1, 'mean', True)

        gradcheck(ctc_after_softmax, [x, rv_x])
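
# GPU/CPU consistency test: run the same random inputs through the CPU and GPU
# implementations and check that losses and gradients agree within tolerance.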
def test_ctc_loss_custom_gpu(device, fp=torch.float32):
    print('testing GPU gradient for %s' % str(fp))
    batch_size = 64
    num_labels = 101
    target_length = 15
    gradcheck_input_size = 10

    ZERO_NONE = 0
    ZERO_SOME = 1
    ZERO_ALL = 2
    # input_length, vary_lengths, zero_lengths
    tests = [(150, False, ZERO_NONE),
             (150, True, ZERO_NONE),
             (50, True, ZERO_SOME),
             (50, True, ZERO_ALL)]
    tests += [(50, False, ZERO_NONE),
              (50, True, ZERO_NONE),
              (150, True, ZERO_SOME),
              (150, True, ZERO_ALL)]

    for input_length, vary_lengths, zero_mode in tests:
        targets = torch.randint(1, num_labels, (batch_size, target_length),
                                device=device, dtype=torch.long)
        x = torch.randn(gradcheck_input_size, dtype=fp, device=device)
        tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
                                   device=device)
        input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                          if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
        if zero_mode == ZERO_ALL:
            target_lengths = [0 for _ in range(batch_size)]
        else:
            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
                               if vary_lengths else target_length) for _ in range(batch_size)]
            if zero_mode == ZERO_SOME:
                idxes = torch.randint(0, batch_size, (10,))
                for i in idxes:
                    target_lengths[i] = 0

        num_realval = np.random.randint(1, 16)
        rv_x = torch.randn(gradcheck_input_size, dtype=fp, device=device)
        tile_factors_rv = torch.randn(batch_size * input_length * num_realval // gradcheck_input_size + 1,
                                      device=device)
        targets_realvals = torch.randn(batch_size, input_length, num_realval, dtype=fp)
        blank1 = np.random.randint(1, num_labels - 1)

        # Build the full-size inputs from the small random vectors, as in the CPU test.
        x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
                  .view(batch_size, input_length, num_labels))
        rv_full = ((rv_x[:, None] * tile_factors_rv[None, :]).view(-1)[:input_length * batch_size * num_realval]
                   .view(batch_size, input_length, num_realval))
        log_probs = torch.log_softmax(x_full, 2)
        log_probs.requires_grad_()
        rv_full.requires_grad_()

        # Reference forward/backward pass with the CPU implementation.
        grad_out = torch.randn(batch_size, device='cpu', dtype=fp)
        loss_native = custom_ctc_f(log_probs, targets, rv_full, targets_realvals,
                                   input_lengths, target_lengths, 1, 0, blank1, 'none', True)
        grad_native = torch.autograd.grad(loss_native, [log_probs, rv_full], grad_out)
        if torch.any(loss_native < 0):
            breakpoint()  # a negative CTC loss should not happen; drop into the debugger

        # Same forward/backward pass with the GPU implementation.
        log_probs.requires_grad_(False)
        rv_full.requires_grad_(False)
        log_probs = log_probs.cuda()
        rv_full = rv_full.cuda()
        log_probs.requires_grad_()
        rv_full.requires_grad_()
        targets = targets.cuda()
        targets_realvals = targets_realvals.cuda()
        loss_gpu = custom_ctc_f_gpu(log_probs, targets, rv_full, targets_realvals,
                                    input_lengths, target_lengths, 1, 0, blank1, 'none', True)
        grad_gpu = torch.autograd.grad(loss_gpu, [log_probs, rv_full], grad_out.cuda())

        # Losses and gradients of the two implementations should agree within tolerance.
        assert torch.allclose(loss_native, loss_gpu.cpu(), rtol=1e-4, atol=1e-4)
        print((grad_native[0] - grad_gpu[0].cpu()).abs().sum())
        if not torch.allclose(grad_native[0], grad_gpu[0].cpu(), rtol=1e-2, atol=1e-2):
            breakpoint()
        print((grad_native[1] - grad_gpu[1].cpu()).abs().sum())
        assert torch.allclose(grad_native[1], grad_gpu[1].cpu(), rtol=1e-2, atol=1e-2)
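
# Run the CPU gradcheck once, then repeatedly compare CPU and GPU results for
# float32 and float64 inputs (the float16 variant is left disabled).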
if __name__ == '__main__':
    test_ctc_loss_custom('cpu:0')
    for _ in range(100):
        test_ctc_loss_custom_gpu('cpu:0')
        test_ctc_loss_custom_gpu('cpu:0', torch.double)
        # test_ctc_loss_custom_gpu('cpu:0', torch.half)
    print('test passed')