File size: 5,917 Bytes
d916065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Ilia Kurenkov <[email protected]>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

import unittest
from collections import Counter
from timeit import timeit

from nltk.lm import Vocabulary


class NgramModelVocabularyTests(unittest.TestCase):
    """Unit tests for the ``nltk.lm.Vocabulary`` class.

    The shared fixture is built once per class: tokens "a", "b", "d", "e"
    each occur twice (meeting the ``unk_cutoff=2`` threshold), while
    "z", "c", "f", "g", "w" occur only once and therefore map to the
    unknown label.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Counts: a=2, b=2, d=2, e=2; c, f, g, w, z = 1 each.
        cls.vocab = Vocabulary(
            ["z", "a", "b", "c", "f", "d", "e", "g", "a", "d", "b", "e", "w"],
            unk_cutoff=2,
        )

    def test_truthiness(self):
        """A non-empty vocabulary is truthy."""
        self.assertTrue(self.vocab)

    def test_cutoff_value_set_correctly(self):
        """The constructor stores the ``unk_cutoff`` argument on ``cutoff``."""
        self.assertEqual(self.vocab.cutoff, 2)

    def test_unable_to_change_cutoff(self):
        """``cutoff`` is read-only after construction."""
        with self.assertRaises(AttributeError):
            self.vocab.cutoff = 3

    def test_cutoff_setter_checks_value(self):
        """A cutoff below 1 is rejected with a descriptive ValueError."""
        with self.assertRaises(ValueError) as exc_info:
            Vocabulary("abc", unk_cutoff=0)
        expected_error_msg = "Cutoff value cannot be less than 1. Got: 0"
        self.assertEqual(expected_error_msg, str(exc_info.exception))

    def test_counts_set_correctly(self):
        """Raw token counts are kept regardless of the cutoff."""
        self.assertEqual(self.vocab.counts["a"], 2)
        self.assertEqual(self.vocab.counts["b"], 2)
        self.assertEqual(self.vocab.counts["c"], 1)

    def test_membership_check_respects_cutoff(self):
        """``in`` only admits tokens seen at least ``cutoff`` times."""
        # "a" was seen 2 times, so it should be considered part of the vocabulary.
        self.assertTrue("a" in self.vocab)
        # "c" was seen once, it shouldn't be considered part of the vocab.
        self.assertFalse("c" in self.vocab)
        # "z" was never seen at all, also shouldn't be considered in the vocab.
        self.assertFalse("z" in self.vocab)

    def test_vocab_len_respects_cutoff(self):
        """``len`` counts only above-cutoff tokens, plus the unknown label."""
        # Vocab size is the number of unique tokens that occur at least as often
        # as the cutoff value (a, b, d, e), plus 1 to account for unknown words.
        self.assertEqual(5, len(self.vocab))

    def test_vocab_iter_respects_cutoff(self):
        """Iterating yields only in-vocabulary items; counts keep everything."""
        vocab_counts = ["a", "b", "c", "d", "e", "f", "g", "w", "z"]
        vocab_items = ["a", "b", "d", "e", "<UNK>"]

        self.assertCountEqual(vocab_counts, list(self.vocab.counts.keys()))
        self.assertCountEqual(vocab_items, list(self.vocab))

    def test_update_empty_vocab(self):
        """An empty vocabulary is falsy but still contains the unk label."""
        empty = Vocabulary(unk_cutoff=2)
        self.assertEqual(len(empty), 0)
        self.assertFalse(empty)
        self.assertIn(empty.unk_label, empty)

        # Updating must not dislodge the unknown label.
        empty.update(list("abcde"))
        self.assertIn(empty.unk_label, empty)

    def test_lookup(self):
        """Known tokens map to themselves; below-cutoff tokens to the unk label."""
        self.assertEqual(self.vocab.lookup("a"), "a")
        self.assertEqual(self.vocab.lookup("c"), "<UNK>")

    def test_lookup_iterables(self):
        """Iterable inputs are mapped element-wise and returned as a tuple."""
        self.assertEqual(self.vocab.lookup(["a", "b"]), ("a", "b"))
        self.assertEqual(self.vocab.lookup(("a", "b")), ("a", "b"))
        self.assertEqual(self.vocab.lookup(("a", "c")), ("a", "<UNK>"))
        self.assertEqual(
            self.vocab.lookup(map(str, range(3))), ("<UNK>", "<UNK>", "<UNK>")
        )

    def test_lookup_empty_iterables(self):
        """Empty iterables of any flavor map to an empty tuple."""
        self.assertEqual(self.vocab.lookup(()), ())
        self.assertEqual(self.vocab.lookup([]), ())
        self.assertEqual(self.vocab.lookup(iter([])), ())
        self.assertEqual(self.vocab.lookup(n for n in range(0, 0)), ())

    def test_lookup_recursive(self):
        """Nested iterables are mapped recursively, preserving nesting depth."""
        self.assertEqual(
            self.vocab.lookup([["a", "b"], ["a", "c"]]), (("a", "b"), ("a", "<UNK>"))
        )
        self.assertEqual(self.vocab.lookup([["a", "b"], "c"]), (("a", "b"), "<UNK>"))
        self.assertEqual(self.vocab.lookup([[[[["a", "b"]]]]]), ((((("a", "b"),),),),))

    def test_lookup_None(self):
        """``None`` is not a valid token, alone or inside an iterable."""
        with self.assertRaises(TypeError):
            self.vocab.lookup(None)
        with self.assertRaises(TypeError):
            list(self.vocab.lookup([None, None]))

    def test_lookup_int(self):
        """Non-string scalars like ints are rejected with TypeError."""
        with self.assertRaises(TypeError):
            self.vocab.lookup(1)
        with self.assertRaises(TypeError):
            list(self.vocab.lookup([1, 2]))

    def test_lookup_empty_str(self):
        """The empty string was never counted, so it maps to the unk label."""
        self.assertEqual(self.vocab.lookup(""), "<UNK>")

    def test_equality(self):
        """Equality requires the same counts, cutoff, and unk label."""
        v1 = Vocabulary(["a", "b", "c"], unk_cutoff=1)
        v2 = Vocabulary(["a", "b", "c"], unk_cutoff=1)
        v3 = Vocabulary(["a", "b", "c"], unk_cutoff=1, unk_label="blah")
        v4 = Vocabulary(["a", "b"], unk_cutoff=1)

        self.assertEqual(v1, v2)
        self.assertNotEqual(v1, v3)  # different unk_label
        self.assertNotEqual(v1, v4)  # different counts

    def test_str(self):
        """``str`` reports cutoff, unk label, and item count."""
        self.assertEqual(
            str(self.vocab), "<Vocabulary with cutoff=2 unk_label='<UNK>' and 5 items>"
        )

    def test_creation_with_counter(self):
        """Constructing from a pre-built Counter is equivalent to a token list."""
        self.assertEqual(
            self.vocab,
            Vocabulary(
                Counter(
                    ["z", "a", "b", "c", "f", "d", "e", "g", "a", "d", "b", "e", "w"]
                ),
                unk_cutoff=2,
            ),
        )

    @unittest.skip(
        reason="Test is known to be flaky as it compares (runtime) performance."
    )
    def test_len_is_constant(self):
        """``len(vocab)`` should be O(1) regardless of vocabulary size."""
        # Given an obviously small and an obviously large vocabulary.
        small_vocab = Vocabulary("abcde")
        from nltk.corpus.europarl_raw import english

        large_vocab = Vocabulary(english.words())

        # If we time calling `len` on them.
        small_vocab_len_time = timeit("len(small_vocab)", globals=locals())
        large_vocab_len_time = timeit("len(large_vocab)", globals=locals())

        # The timing should be the same order of magnitude.
        self.assertAlmostEqual(small_vocab_len_time, large_vocab_len_time, places=1)