File size: 5,549 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import unittest
from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss
import itertools
from copy import deepcopy
import torch
from torch.nn.functional import softmax
from onmt.tests.utils_for_tests import product_dict
class TestCopyGenerator(unittest.TestCase):
    """Unit tests for CopyGenerator: output shape, zero probability on the
    pad index, and gradient flow through every trainable parameter."""

    # Constructor-argument combinations under test.
    INIT_CASES = list(product_dict(
        input_size=[172],
        output_size=[319],
        pad_idx=[0, 39],
    ))
    # Runtime input-size combinations under test.
    PARAMS = list(product_dict(
        batch_size=[1, 14],
        max_seq_len=[23],
        tgt_max_len=[50],
        n_extra_words=[107]
    ))

    @classmethod
    def dummy_inputs(cls, params, init_case):
        """Build random (hidden, attn, src_map) tensors with the shapes
        CopyGenerator.forward expects for this case."""
        n_rows = params["batch_size"] * params["tgt_max_len"]
        hidden = torch.randn((n_rows, init_case["input_size"]))
        attn = torch.randn((n_rows, params["max_seq_len"]))
        src_map = torch.randn(
            (params["max_seq_len"], params["batch_size"],
             params["n_extra_words"]))
        return hidden, attn, src_map

    @classmethod
    def expected_shape(cls, params, init_case):
        """Shape the generator output should have: one row per target
        token, one column per (vocab + extra copied) word."""
        n_rows = params["tgt_max_len"] * params["batch_size"]
        n_cols = init_case["output_size"] + params["n_extra_words"]
        return n_rows, n_cols

    def test_copy_gen_forward_shape(self):
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            generator = CopyGenerator(**init_case)
            out = generator(*self.dummy_inputs(params, init_case))
            self.assertEqual(
                out.shape, self.expected_shape(params, init_case),
                str(init_case))

    def test_copy_gen_outp_has_no_prob_of_pad(self):
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            generator = CopyGenerator(**init_case)
            out = generator(*self.dummy_inputs(params, init_case))
            # The generator must assign (numerically) zero mass to pad.
            pad_column = out[:, init_case["pad_idx"]]
            self.assertTrue(pad_column.allclose(torch.tensor(0.0)))

    def test_copy_gen_trainable_params_update(self):
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            generator = CopyGenerator(**init_case)
            trainable = {name: p for name, p in generator.named_parameters()
                         if p.requires_grad}
            assert len(trainable) > 0  # sanity check
            snapshot = deepcopy(trainable)
            out = generator(*self.dummy_inputs(params, init_case))
            # Any scalar works as a pretend loss; sum keeps all params in
            # the graph so every one receives a gradient.
            out.sum().backward()
            sgd = torch.optim.SGD(trainable.values(), 1)
            sgd.step()
            for name in snapshot:
                self.assertTrue(
                    trainable[name].ne(snapshot[name]).any(),
                    f"{name} {init_case}")
class TestCopyGeneratorLoss(unittest.TestCase):
    """Unit tests for CopyGeneratorLoss: output shape, ignore_index
    masking, and non-negativity of the per-token loss."""

    # Constructor-argument combinations under test.
    INIT_CASES = list(product_dict(
        vocab_size=[172],
        unk_index=[0, 39],
        ignore_index=[1, 17],  # pad idx
        force_copy=[True, False]
    ))
    # Runtime input-size combinations under test.
    PARAMS = list(product_dict(
        batch_size=[1, 14],
        tgt_max_len=[50],
        n_extra_words=[107]
    ))

    @classmethod
    def dummy_inputs(cls, params, init_case):
        """Random (scores, align, target) for the loss.  Positions 0 and 1
        of target are pinned to unk/ignore so both paths are exercised."""
        n_unique_src_words = 13
        n_rows = params["batch_size"] * params["tgt_max_len"]
        raw = torch.randn(
            (n_rows, init_case["vocab_size"] + n_unique_src_words))
        scores = softmax(raw, dim=1)
        align = torch.randint(0, n_unique_src_words, (n_rows,))
        target = torch.randint(0, init_case["vocab_size"], (n_rows,))
        target[0] = init_case["unk_index"]
        target[1] = init_case["ignore_index"]
        return scores, align, target

    @classmethod
    def expected_shape(cls, params, init_case):
        """The loss is unreduced: one value per target token."""
        return (params["batch_size"] * params["tgt_max_len"],)

    def test_copy_loss_forward_shape(self):
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            criterion = CopyGeneratorLoss(**init_case)
            out = criterion(*self.dummy_inputs(params, init_case))
            self.assertEqual(
                out.shape, self.expected_shape(params, init_case),
                str(init_case))

    def test_copy_loss_ignore_index_is_ignored(self):
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            criterion = CopyGeneratorLoss(**init_case)
            scores, align, target = self.dummy_inputs(params, init_case)
            out = criterion(scores, align, target)
            ignored = (target == init_case["ignore_index"]).nonzero(
                as_tuple=False)
            assert len(ignored) > 0  # otherwise not testing anything
            # Ignored positions must contribute (numerically) zero loss.
            self.assertTrue(out[ignored].allclose(torch.tensor(0.0)))

    def test_copy_loss_output_range_is_positive(self):
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            criterion = CopyGeneratorLoss(**init_case)
            out = criterion(*self.dummy_inputs(params, init_case))
            self.assertTrue((out >= 0).all())
|