File size: 5,549 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import unittest
from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss

import itertools
from copy import deepcopy

import torch
from torch.nn.functional import softmax

from onmt.tests.utils_for_tests import product_dict


class TestCopyGenerator(unittest.TestCase):
    """Grid-based unit tests for ``CopyGenerator``.

    Every test iterates the cartesian product of ``PARAMS`` (runtime
    shapes) and ``INIT_CASES`` (constructor arguments) so each shape /
    config combination is exercised.
    """

    # Constructor-argument grid for CopyGenerator.
    INIT_CASES = list(product_dict(
        input_size=[172],
        output_size=[319],
        pad_idx=[0, 39],
    ))
    # Runtime-shape grid for the dummy forward inputs.
    PARAMS = list(product_dict(
        batch_size=[1, 14],
        max_seq_len=[23],
        tgt_max_len=[50],
        n_extra_words=[107]
    ))

    @classmethod
    def dummy_inputs(cls, params, init_case):
        """Return random ``(hidden, attn, src_map)`` tensors for one config."""
        n_rows = params["batch_size"] * params["tgt_max_len"]
        hidden = torch.randn((n_rows, init_case["input_size"]))
        attn = torch.randn((n_rows, params["max_seq_len"]))
        src_map = torch.randn((params["max_seq_len"],
                               params["batch_size"],
                               params["n_extra_words"]))
        return hidden, attn, src_map

    @classmethod
    def expected_shape(cls, params, init_case):
        """Shape the generator's output should have for this config."""
        n_rows = params["tgt_max_len"] * params["batch_size"]
        return n_rows, init_case["output_size"] + params["n_extra_words"]

    def test_copy_gen_forward_shape(self):
        """Forward output matches the analytically expected shape."""
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            generator = CopyGenerator(**init_case)
            out = generator(*self.dummy_inputs(params, init_case))
            self.assertEqual(out.shape,
                             self.expected_shape(params, init_case),
                             str(init_case))

    def test_copy_gen_outp_has_no_prob_of_pad(self):
        """The pad index column carries (numerically) zero probability."""
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            generator = CopyGenerator(**init_case)
            out = generator(*self.dummy_inputs(params, init_case))
            pad_column = out[:, init_case["pad_idx"]]
            self.assertTrue(pad_column.allclose(torch.tensor(0.0)))

    def test_copy_gen_trainable_params_update(self):
        """One SGD step after backward changes every trainable parameter."""
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            generator = CopyGenerator(**init_case)
            trainable = {name: p for name, p in generator.named_parameters()
                         if p.requires_grad}
            assert len(trainable) > 0  # sanity check
            snapshot = deepcopy(trainable)
            out = generator(*self.dummy_inputs(params, init_case))
            # A sum makes a scalar "loss" so backward() needs no grad arg.
            out.sum().backward()
            torch.optim.SGD(trainable.values(), 1).step()
            for name, before in snapshot.items():
                self.assertTrue(trainable[name].ne(before).any(),
                                name + " " + str(init_case))


class TestCopyGeneratorLoss(unittest.TestCase):
    """Grid-based unit tests for ``CopyGeneratorLoss``.

    Mirrors ``TestCopyGenerator``: each test walks the product of the
    runtime-shape grid (``PARAMS``) and the constructor grid
    (``INIT_CASES``).
    """

    # Constructor-argument grid for CopyGeneratorLoss.
    INIT_CASES = list(product_dict(
        vocab_size=[172],
        unk_index=[0, 39],
        ignore_index=[1, 17],  # pad idx
        force_copy=[True, False]
    ))
    # Runtime-shape grid for the dummy loss inputs.
    PARAMS = list(product_dict(
        batch_size=[1, 14],
        tgt_max_len=[50],
        n_extra_words=[107]
    ))

    @classmethod
    def dummy_inputs(cls, params, init_case):
        """Return random ``(scores, align, target)`` for one config.

        ``scores`` is a proper probability distribution over the vocab
        plus a fixed number of source-copy slots. Positions 0 and 1 of
        ``target`` are overwritten so the unk and ignore indices are
        guaranteed to occur at least once.
        """
        n_unique_src_words = 13
        n_rows = params["batch_size"] * params["tgt_max_len"]
        scores = softmax(
            torch.randn((n_rows,
                         init_case["vocab_size"] + n_unique_src_words)),
            dim=1)
        align = torch.randint(0, n_unique_src_words, (n_rows,))
        target = torch.randint(0, init_case["vocab_size"], (n_rows,))
        target[0] = init_case["unk_index"]
        target[1] = init_case["ignore_index"]
        return scores, align, target

    @classmethod
    def expected_shape(cls, params, init_case):
        """Loss is unreduced: one value per target token."""
        return (params["batch_size"] * params["tgt_max_len"],)

    def test_copy_loss_forward_shape(self):
        """Loss output matches the analytically expected shape."""
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            loss_fn = CopyGeneratorLoss(**init_case)
            out = loss_fn(*self.dummy_inputs(params, init_case))
            self.assertEqual(out.shape,
                             self.expected_shape(params, init_case),
                             str(init_case))

    def test_copy_loss_ignore_index_is_ignored(self):
        """Positions whose target equals ignore_index contribute zero loss."""
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            loss_fn = CopyGeneratorLoss(**init_case)
            scores, align, target = self.dummy_inputs(params, init_case)
            out = loss_fn(scores, align, target)
            ignored_positions = (
                target == init_case["ignore_index"]).nonzero(as_tuple=False)
            assert len(ignored_positions) > 0  # otherwise not testing anything
            self.assertTrue(
                out[ignored_positions].allclose(torch.tensor(0.0)))

    def test_copy_loss_output_range_is_positive(self):
        """NLL-style loss is non-negative everywhere."""
        for params, init_case in itertools.product(
                self.PARAMS, self.INIT_CASES):
            loss_fn = CopyGeneratorLoss(**init_case)
            out = loss_fn(*self.dummy_inputs(params, init_case))
            self.assertTrue((out >= 0).all())