import torch

import custom_ctc
import custom_ctc_gpu

import numpy as np

import torch.nn.functional as F
from torch.autograd import gradcheck

custom_ctc_f = custom_ctc.CustomCTCLossFunction.apply
custom_ctc_f_gpu = custom_ctc_gpu.CustomCTCLossFunction.apply

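
# CPU check: run torch.autograd.gradcheck on the custom CTC loss through a
# log_softmax. Small double-precision seed vectors are broadcast against random
# tile factors to fill the full (batch, time, labels) inputs, so the numerical
# Jacobian gradcheck has to build stays small.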
def test_ctc_loss_custom(device):
    batch_size = 64
    num_labels = 101
    target_length = 15
    gradcheck_input_size = 10

    # how many of the per-sample target lengths are forced to zero
    ZERO_NONE = 0
    ZERO_SOME = 1
    ZERO_ALL = 2

    # each entry is (input_length, vary_lengths, zero_mode)
    tests = [(150, False, ZERO_NONE),
             (150, True, ZERO_NONE),
             (50, True, ZERO_SOME),
             (50, True, ZERO_ALL)]

    tests += [(50, False, ZERO_NONE),
              (50, True, ZERO_NONE),
              (150, True, ZERO_SOME),
              (150, True, ZERO_ALL)]

    for input_length, vary_lengths, zero_mode in tests:
        targets = torch.randint(1, num_labels, (batch_size, target_length),
                                device=device, dtype=torch.long)
        x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True)
        tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
                                   device=device)
        input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                          if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
        if zero_mode == ZERO_ALL:
            target_lengths = [0 for _ in range(batch_size)]
        else:
            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
                               if vary_lengths else target_length) for _ in range(batch_size)]
            if zero_mode == ZERO_SOME:
                idxes = torch.randint(0, batch_size, (10,))
                for i in idxes:
                    target_lengths[i] = 0

        num_realval = np.random.randint(1, 16)
        rv_x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True)
        tile_factors_rv = torch.randn(batch_size * input_length * num_realval // gradcheck_input_size + 1,
                                      device=device)

        targets_realvals = torch.randn(batch_size, input_length, num_realval, dtype=torch.double, device=device)

        blank1 = np.random.randint(1, num_labels - 1)

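        # expand both small seed vectors to full-size log-probabilities and
        # real-valued inputs, then evaluate the custom loss with 'mean' reduction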
        def ctc_after_softmax(x, rv):
            x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
                      .view(batch_size, input_length, num_labels))
            rv_full = ((rv[:, None] * tile_factors_rv[None, :]).view(-1)[:input_length * batch_size * num_realval]
                       .view(batch_size, input_length, num_realval))
            log_probs = torch.log_softmax(x_full, 2)
            return custom_ctc_f(log_probs, targets, rv_full, targets_realvals,
                                input_lengths, target_lengths, 2.2, 0, blank1, 'mean', True)

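        # compare analytic and numerical gradients w.r.t. both seed vectors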
        gradcheck(ctc_after_softmax, [x, rv_x])

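
# CPU-vs-GPU check: evaluate the loss and its gradients with the native
# implementation, repeat with the GPU implementation on CUDA copies of the same
# inputs, and require the results to agree within loose tolerances.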
def test_ctc_loss_custom_gpu(device, fp=torch.float32):
    print('testing GPU gradient for %s' % str(fp))
    batch_size = 64
    num_labels = 101
    target_length = 15
    gradcheck_input_size = 10

    # how many of the per-sample target lengths are forced to zero
    ZERO_NONE = 0
    ZERO_SOME = 1
    ZERO_ALL = 2

    # each entry is (input_length, vary_lengths, zero_mode)
    tests = [(150, False, ZERO_NONE),
             (150, True, ZERO_NONE),
             (50, True, ZERO_SOME),
             (50, True, ZERO_ALL)]

    tests += [(50, False, ZERO_NONE),
              (50, True, ZERO_NONE),
              (150, True, ZERO_SOME),
              (150, True, ZERO_ALL)]

    for input_length, vary_lengths, zero_mode in tests:
        targets = torch.randint(1, num_labels, (batch_size, target_length),
                                device=device, dtype=torch.long)
        x = torch.randn(gradcheck_input_size, dtype=fp, device=device)
        tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
                                   device=device)
        input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                          if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
        if zero_mode == ZERO_ALL:
            target_lengths = [0 for _ in range(batch_size)]
        else:
            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
                               if vary_lengths else target_length) for _ in range(batch_size)]
            if zero_mode == ZERO_SOME:
                idxes = torch.randint(0, batch_size, (10,))
                for i in idxes:
                    target_lengths[i] = 0

        num_realval = np.random.randint(1, 16)
        rv_x = torch.randn(gradcheck_input_size, dtype=fp, device=device)
        tile_factors_rv = torch.randn(batch_size * input_length * num_realval // gradcheck_input_size + 1,
                                      device=device)

        targets_realvals = torch.randn(batch_size, input_length, num_realval, dtype=fp)

        blank1 = np.random.randint(1, num_labels - 1)

        x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
                  .view(batch_size, input_length, num_labels))
        rv_full = ((rv_x[:, None] * tile_factors_rv[None, :]).view(-1)[:input_length * batch_size * num_realval]
                   .view(batch_size, input_length, num_realval))
        log_probs = torch.log_softmax(x_full, 2)
        log_probs.requires_grad_()
        rv_full.requires_grad_()

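        # reference: loss and gradients from the native (custom_ctc) implementation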
        grad_out = torch.randn(batch_size, device='cpu', dtype=fp)
        loss_native = custom_ctc_f(log_probs, targets, rv_full, targets_realvals,
                                   input_lengths, target_lengths, 1, 0, blank1, 'none', True)
        grad_native = torch.autograd.grad(loss_native, [log_probs, rv_full], grad_out)
        if torch.any(loss_native < 0):
            # a negative loss would indicate a bug; stop here for inspection
            breakpoint()

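        # repeat on CUDA with the GPU implementation and the same inputs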
        log_probs.requires_grad_(False)
        rv_full.requires_grad_(False)
        log_probs = log_probs.cuda()
        rv_full = rv_full.cuda()
        log_probs.requires_grad_()
        rv_full.requires_grad_()
        targets = targets.cuda()
        targets_realvals = targets_realvals.cuda()

        loss_gpu = custom_ctc_f_gpu(log_probs, targets, rv_full, targets_realvals,
                                    input_lengths, target_lengths, 1, 0, blank1, 'none', True)
        grad_gpu = torch.autograd.grad(loss_gpu, [log_probs, rv_full], grad_out.cuda())

        assert torch.allclose(loss_native, loss_gpu.cpu(), rtol=1e-4, atol=1e-4)
        print((grad_native[0] - grad_gpu[0].cpu()).abs().sum())
        if not torch.allclose(grad_native[0], grad_gpu[0].cpu(), rtol=1e-2, atol=1e-2):
            # log-prob gradients disagree between the two implementations; inspect
            breakpoint()
        print((grad_native[1] - grad_gpu[1].cpu()).abs().sum())
        assert torch.allclose(grad_native[1], grad_gpu[1].cpu(), rtol=1e-2, atol=1e-2)

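
# run the CPU gradcheck once, then the CPU-vs-GPU comparison 100 times in both
# float32 and float64 (shapes and blank index are re-randomized inside each call)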
if __name__ == '__main__':
    test_ctc_loss_custom('cpu:0')
    for _ in range(100):
        test_ctc_loss_custom_gpu('cpu:0')
        test_ctc_loss_custom_gpu('cpu:0', torch.double)

    print('test passed')