File size: 7,295 Bytes
9dce458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import torch
import custom_ctc
import custom_ctc_gpu
import numpy as np
import torch.nn.functional as F
from torch.autograd import gradcheck
# Autograd entry points (``Function.apply``) of the two custom CTC extensions.
# ``custom_ctc_f`` serves as the reference implementation in the tests below;
# ``custom_ctc_f_gpu`` is fed CUDA tensors and compared against it.
custom_ctc_f, custom_ctc_f_gpu = (
    custom_ctc.CustomCTCLossFunction.apply,
    custom_ctc_gpu.CustomCTCLossFunction.apply,
)
def test_ctc_loss_custom(device):
    """Gradient-check ``custom_ctc_f`` against finite differences.

    For every ``(input_length, vary_lengths, zero_mode)`` configuration, two
    small double-precision parameter vectors are tiled up to full-size
    log-probability and real-value inputs, and ``gradcheck`` verifies the
    analytic gradients of the loss with respect to both vectors.

    All randomness comes from torch's RNG, so a single ``torch.manual_seed``
    call before invoking this function makes the run reproducible.

    Args:
        device: Device on which every input tensor is created.

    Raises:
        Whatever ``gradcheck`` raises (it raises by default) when analytic and
        numerical gradients disagree.
    """
    batch_size = 64
    num_labels = 101
    target_length = 15
    gradcheck_input_size = 10
    ZERO_NONE, ZERO_SOME, ZERO_ALL = 0, 1, 2
    # (input_length, vary_lengths, zero_mode)
    tests = [(150, False, ZERO_NONE),
             (150, True, ZERO_NONE),
             (50, True, ZERO_SOME),
             (50, True, ZERO_ALL),
             (50, False, ZERO_NONE),
             (50, True, ZERO_NONE),
             (150, True, ZERO_SOME),
             (150, True, ZERO_ALL)]
    for input_length, vary_lengths, zero_mode in tests:
        targets = torch.randint(1, num_labels, (batch_size, target_length),
                                device=device, dtype=torch.long)
        x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device,
                        requires_grad=True)
        tile_factors = torch.randn(
            input_length * batch_size * num_labels // gradcheck_input_size + 1,
            device=device)
        # Sequence 0 always gets a random length, even when vary_lengths is False.
        input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                          if vary_lengths or i == 0 else input_length)
                         for i in range(batch_size)]
        if zero_mode == ZERO_ALL:
            target_lengths = [0] * batch_size
        else:
            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
                               if vary_lengths else target_length)
                              for _ in range(batch_size)]
            if zero_mode == ZERO_SOME:
                # Zero out 10 randomly drawn (possibly repeated) entries.
                for i in torch.randint(0, batch_size, (10,)).tolist():
                    target_lengths[i] = 0
        # Drawn from torch (not NumPy) so torch.manual_seed controls the whole test;
        # ranges match the original np.random.randint calls.
        num_realval = torch.randint(1, 16, ()).item()
        rv_x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device,
                           requires_grad=True)
        tile_factors_rv = torch.randn(
            batch_size * input_length * num_realval // gradcheck_input_size + 1,
            device=device)
        targets_realvals = torch.randn(batch_size, input_length, num_realval,
                                       dtype=torch.double, device=device)
        blank1 = torch.randint(1, num_labels - 1, ()).item()

        def ctc_after_softmax(x, rv):
            # Expand the small vectors into full inputs so gradcheck only has to
            # perturb gradcheck_input_size elements of each.
            x_full = ((x[:, None] * tile_factors[None, :]).view(-1)
                      [:input_length * batch_size * num_labels]
                      .view(batch_size, input_length, num_labels))
            rv_full = ((rv[:, None] * tile_factors_rv[None, :]).view(-1)
                       [:input_length * batch_size * num_realval]
                       .view(batch_size, input_length, num_realval))
            log_probs = torch.log_softmax(x_full, 2)
            # NOTE(review): meaning of the positional args (2.2, 0, blank1,
            # 'mean', True) is defined by custom_ctc; 'mean' presumably selects
            # the reduction -- confirm against the extension.
            return custom_ctc_f(log_probs, targets, rv_full, targets_realvals,
                                input_lengths, target_lengths, 2.2, 0, blank1,
                                'mean', True)

        # The closure is consumed immediately, so capturing loop variables is safe.
        gradcheck(ctc_after_softmax, [x, rv_x])
def test_ctc_loss_custom_gpu(device, fp=torch.float32):
    """Compare the CUDA custom CTC loss against the reference implementation.

    For every ``(input_length, vary_lengths, zero_mode)`` configuration, random
    inputs are built on ``device``. The loss (reduction ``'none'``) and its
    gradients with respect to the log-probabilities and the real-valued
    predictions are computed with ``custom_ctc_f`` there, and with
    ``custom_ctc_f_gpu`` on copies moved to CUDA. The results must agree.

    All randomness comes from torch's RNG, so a single ``torch.manual_seed``
    call before invoking this function makes the run reproducible.

    Args:
        device: Device for the reference computation; inputs are created here
            and copied to CUDA afterwards. Callers in this file pass 'cpu:0'.
        fp: Floating-point dtype used for all generated inputs and the upstream
            gradient.

    Raises:
        AssertionError: if the reference loss is negative, or if the losses
            (rtol=atol=1e-4) or gradients (rtol=atol=1e-2) differ.
    """
    print('testing GPU gradient for %s' % str(fp))
    batch_size = 64
    num_labels = 101
    target_length = 15
    gradcheck_input_size = 10
    ZERO_NONE, ZERO_SOME, ZERO_ALL = 0, 1, 2
    # (input_length, vary_lengths, zero_mode)
    tests = [(150, False, ZERO_NONE),
             (150, True, ZERO_NONE),
             (50, True, ZERO_SOME),
             (50, True, ZERO_ALL),
             (50, False, ZERO_NONE),
             (50, True, ZERO_NONE),
             (150, True, ZERO_SOME),
             (150, True, ZERO_ALL)]
    for input_length, vary_lengths, zero_mode in tests:
        targets = torch.randint(1, num_labels, (batch_size, target_length),
                                device=device, dtype=torch.long)
        # Tile factors are generated in `fp` too; a float32 tile would otherwise
        # silently promote lower-precision inputs (e.g. torch.half) to float32.
        x = torch.randn(gradcheck_input_size, dtype=fp, device=device)
        tile_factors = torch.randn(
            input_length * batch_size * num_labels // gradcheck_input_size + 1,
            dtype=fp, device=device)
        # Sequence 0 always gets a random length, even when vary_lengths is False.
        input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                          if vary_lengths or i == 0 else input_length)
                         for i in range(batch_size)]
        if zero_mode == ZERO_ALL:
            target_lengths = [0] * batch_size
        else:
            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
                               if vary_lengths else target_length)
                              for _ in range(batch_size)]
            if zero_mode == ZERO_SOME:
                # Zero out 10 randomly drawn (possibly repeated) entries.
                for i in torch.randint(0, batch_size, (10,)).tolist():
                    target_lengths[i] = 0
        # Drawn from torch (not NumPy) so torch.manual_seed controls the whole test;
        # ranges match the original np.random.randint calls.
        num_realval = torch.randint(1, 16, ()).item()
        rv_x = torch.randn(gradcheck_input_size, dtype=fp, device=device)
        tile_factors_rv = torch.randn(
            batch_size * input_length * num_realval // gradcheck_input_size + 1,
            dtype=fp, device=device)
        targets_realvals = torch.randn(batch_size, input_length, num_realval,
                                       dtype=fp, device=device)
        blank1 = torch.randint(1, num_labels - 1, ()).item()

        x_full = ((x[:, None] * tile_factors[None, :]).view(-1)
                  [:input_length * batch_size * num_labels]
                  .view(batch_size, input_length, num_labels))
        rv_full = ((rv_x[:, None] * tile_factors_rv[None, :]).view(-1)
                   [:input_length * batch_size * num_realval]
                   .view(batch_size, input_length, num_realval))
        # Neither x nor rv_x requires grad, so these can be marked as the tensors
        # autograd.grad differentiates with respect to.
        log_probs = torch.log_softmax(x_full, 2)
        log_probs.requires_grad_()
        rv_full.requires_grad_()
        grad_out = torch.randn(batch_size, device=device, dtype=fp)

        # NOTE(review): positional args (1, 0, blank1, 'none', True) are passed
        # through as in the original; their meaning is defined by custom_ctc.
        loss_native = custom_ctc_f(log_probs, targets, rv_full, targets_realvals,
                                   input_lengths, target_lengths, 1, 0, blank1,
                                   'none', True)
        grad_native = torch.autograd.grad(loss_native, [log_probs, rv_full], grad_out)
        assert not torch.any(loss_native < 0).item(), (
            'reference loss is negative (min %g)' % loss_native.min().item())

        # Fresh CUDA leaves holding the same values as the reference inputs.
        log_probs_gpu = log_probs.detach().cuda().requires_grad_()
        rv_full_gpu = rv_full.detach().cuda().requires_grad_()
        loss_gpu = custom_ctc_f_gpu(log_probs_gpu, targets.cuda(), rv_full_gpu,
                                    targets_realvals.cuda(), input_lengths,
                                    target_lengths, 1, 0, blank1, 'none', True)
        grad_gpu = torch.autograd.grad(loss_gpu, [log_probs_gpu, rv_full_gpu],
                                       grad_out.cuda())

        loss_gpu_ref = loss_gpu.to(loss_native.device)
        assert torch.allclose(loss_native, loss_gpu_ref, rtol=1e-4, atol=1e-4), (
            'loss mismatch (max |diff| = %g)'
            % (loss_native - loss_gpu_ref).abs().max().item())
        for name, g_ref, g_gpu in zip(('log_probs', 'realval'), grad_native, grad_gpu):
            g_gpu = g_gpu.to(g_ref.device)
            assert torch.allclose(g_ref, g_gpu, rtol=1e-2, atol=1e-2), (
                '%s gradient mismatch (sum |diff| = %g)'
                % (name, (g_ref - g_gpu).abs().sum().item()))
if __name__ == '__main__':
    # Finite-difference gradient check of the reference implementation.
    test_ctc_loss_custom('cpu:0')
    # Repeat the randomized reference-vs-CUDA comparison many times, in single
    # precision (the function's default dtype) and double precision.
    # Half precision (torch.half) is not exercised.
    for _iteration in range(100):
        for precision in (torch.float32, torch.double):
            test_ctc_loss_custom_gpu('cpu:0', precision)
    print('test passed')
|