# File size: 3,361 Bytes
# d758c99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import collections
import collections.abc
import json
import operator


class Sentinel(object):
    '''Used to represent special values like UNK.

    A sentinel carries only a `name`. Its repr is the name wrapped in angle
    brackets, and it orders before any non-sentinel object.
    '''
    # pylint: disable=too-few-public-methods
    __slots__ = ('name', )

    def __init__(self, name):
        # type: (str) -> None
        '''Store `name`, the label shown by `repr`.'''
        self.name = name

    def __repr__(self):
        # type: () -> str
        '''Return the name wrapped in angle brackets, e.g. ``<UNK>``.'''
        return '<' + self.name + '>'

    def __lt__(self, other):
        # type: (object) -> bool
        '''Order sentinels by name; a sentinel sorts before anything else.

        Returns:
            The result of comparing names when `other` is a `Sentinel`,
            otherwise always True.
        '''
        # The previous check named `IndexedSet.Sentinel`, which does not
        # exist in this module, so every comparison raised NameError.
        if isinstance(other, Sentinel):
            return self.name < other.name
        return True


# Default special elements; `Vocab` places them first, in this order.
UNK = '<UNK>'  # id of this element is returned by Vocab.index for unknown values
BOS = '<BOS>'  # presumably a beginning-of-sequence marker -- confirm with callers
EOS = '<EOS>'  # presumably an end-of-sequence marker -- confirm with callers


class Vocab(collections.abc.Set):
    '''Set of unique elements where each element has a stable integer id.

    Ids are assigned by position: the entries of `special_elems` come first,
    followed by the entries of `iterable`, each in iteration order.

    Attributes are kept as two plain dicts:
      * `id_to_elem` maps id -> element
      * `elem_to_id` maps element -> id
    '''

    def __init__(self, iterable, special_elems=(UNK, BOS, EOS)):
        # type: (Iterable[T], Iterable[T]) -> None
        '''Build the vocabulary.

        Args:
            iterable: the regular elements, numbered after the special ones.
            special_elems: elements placed at the lowest ids.

        Raises:
            AssertionError: if any element occurs more than once (the check
                is an `assert`, so it is skipped under ``python -O``).
        '''
        elements = list(special_elems)
        elements.extend(iterable)
        assert len(elements) == len(set(elements)), \
            'Vocab elements must be unique'

        self.id_to_elem = dict(enumerate(elements))
        self.elem_to_id = {elem: i for i, elem in enumerate(elements)}

    @classmethod
    def _from_iterable(cls, it):
        # type: (Iterable[T]) -> Vocab
        '''Construct the result of the inherited set operators (&, |, -, ^).

        The Set mixin calls this hook with the result elements. Those already
        contain whatever special elements survive the operation (and, for
        ``|``, contain them twice), so duplicates are dropped while keeping
        first-seen order and no special elements are prepended again.
        '''
        return cls(collections.OrderedDict.fromkeys(it), special_elems=())

    def __iter__(self):
        # type: () -> Iterator[T]
        '''Yield the elements in id order.'''
        for i in range(len(self)):
            yield self.id_to_elem[i]

    def __contains__(self, value):
        # type: (object) -> bool
        '''Return True if `value` is an element of the vocabulary.'''
        return value in self.elem_to_id

    def __len__(self):
        # type: () -> int
        '''Return the number of elements, special ones included.'''
        return len(self.elem_to_id)

    def __getitem__(self, key):
        # type: (int) -> T
        '''Return the element whose id is `key`.

        Raises:
            TypeError: if `key` is a slice.
            KeyError: if `key` is not an assigned id.
        '''
        if isinstance(key, slice):
            raise TypeError('Slices not supported.')
        return self.id_to_elem[key]

    def index(self, value):
        # type: (T) -> int
        '''Return the id of `value`, or the id of UNK if `value` is unknown.

        Raises:
            KeyError: if `value` is unknown and the vocabulary has no UNK
                element; the error names `value` rather than UNK.
        '''
        elem_to_id = self.elem_to_id
        if value in elem_to_id:
            return elem_to_id[value]
        if UNK in elem_to_id:
            return elem_to_id[UNK]
        raise KeyError(value)

    def indices(self, values):
        # type: (Iterable[T]) -> List[int]
        '''Return `self.index(value)` for every entry of `values`.'''
        return [self.index(value) for value in values]

    def __hash__(self):
        # type: () -> int
        # Identity-based hash. Note that `==` is inherited from the Set mixin
        # and compares contents, so equal vocabularies hash differently.
        return id(self)

    @classmethod
    def load(cls, in_path):
        # type: (str) -> Vocab
        '''Read a vocabulary from the JSON list at `in_path`.

        The list is used as-is, with no special elements prepended; `save`
        writes the special elements as part of the list.
        '''
        with open(in_path) as f:
            elements = json.load(f)
        return cls(elements, special_elems=())

    def save(self, out_path):
        # type: (str) -> None
        '''Write all elements, in id order, to `out_path` as a JSON list.'''
        with open(out_path, 'w') as f:
            json.dump(list(self), f)


class VocabBuilder:
    '''Collects word counts and builds a `Vocab` from the retained words.'''

    def __init__(self, min_freq=None, max_count=None):
        '''Start with empty counts.

        Args:
            min_freq: if not None, words counted fewer times are dropped.
            max_count: if not None, at most this many of the most frequent
                words are considered.
        '''
        self.min_freq = min_freq
        self.max_count = max_count
        self.word_freq = collections.Counter()

    def add_word(self, word, count=1):
        '''Increase the count of `word` by `count`.'''
        self.word_freq[word] = self.word_freq[word] + count

    def finish(self, *args, **kwargs):
        '''Return a `Vocab` of the retained words in sorted order.

        Extra positional and keyword arguments are forwarded to `Vocab`.
        '''
        # `most_common(None)` returns every word; otherwise the top
        # `max_count` words. Either way the list is in descending count order.
        candidates = self.word_freq.most_common(self.max_count)

        threshold = self.min_freq
        if threshold is not None:
            # Counts are descending, so keeping the pairs that are not below
            # the threshold is the same as cutting at the first one that is.
            candidates = [pair for pair in candidates
                          if not pair[1] < threshold]

        ordered = sorted(candidates)
        return Vocab(map(operator.itemgetter(0), ordered), *args, **kwargs)

    def save(self, path):
        '''Write the word counts to `path` as JSON.'''
        with open(path, "w") as out_file:
            json.dump(self.word_freq, out_file)

    def load(self, path):
        '''Replace the word counts with those read from the JSON at `path`.'''
        with open(path, "r") as in_file:
            stored = json.load(in_file)
        self.word_freq = collections.Counter(stored)