File size: 3,472 Bytes
d758c99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import collections
import collections.abc
import json
import operator
class Sentinel(object):
    '''Used to represent special values like UNK.

    A sentinel carries only a `name`; its repr is that name wrapped in angle
    brackets (e.g. ``Sentinel('UNK')`` -> ``<UNK>``).

    Ordering: sentinels compare with each other by name, and a sentinel is
    considered less than any non-sentinel object, so sentinels sort first
    when this `__lt__` is the one consulted.
    '''
    # pylint: disable=too-few-public-methods
    __slots__ = ('name', )

    def __init__(self, name):
        # type: (str) -> None
        self.name = name

    def __repr__(self):
        # type: () -> str
        return '<' + self.name + '>'

    def __lt__(self, other):
        # type: (object) -> bool
        # Refer to this class directly: the previous `IndexedSet.Sentinel`
        # lookup named a container that does not exist in this module and
        # raised NameError on every comparison.
        if isinstance(other, Sentinel):
            return self.name < other.name
        # Any non-sentinel value is ordered after a sentinel.
        return True
# Reserved vocabulary entries. Together they form the default `special_elems`
# of `Vocab`, and `Vocab.index` resolves any value it does not contain to the
# id of UNK. BOS / EOS presumably mark beginning / end of a sequence -- they
# are not otherwise used in the visible code, so confirm against callers.
UNK, BOS, EOS = '<UNK>', '<BOS>', '<EOS>'
class Vocab(collections.abc.Set):
    '''Ordered set of unique elements, each mapped to a dense integer id.

    Ids are assigned by position: the `special_elems` come first (starting at
    0), followed by the elements of `iterable` in iteration order. The two
    dicts `id_to_elem` and `elem_to_id` hold the mapping in both directions.
    '''

    def __init__(self, iterable, special_elems=(UNK, BOS, EOS)):
        # type: (Iterable[T], Iterable[T]) -> None
        '''Build the vocabulary from `special_elems` followed by `iterable`.

        All elements must be hashable and distinct (this is checked with an
        assertion, so the check is skipped under ``python -O``).
        '''
        elements = list(special_elems)
        elements.extend(iterable)
        assert len(elements) == len(set(elements)), \
            'Vocab elements must be unique.'
        self.id_to_elem = dict(enumerate(elements))
        self.elem_to_id = {elem: i for i, elem in enumerate(elements)}

    def __iter__(self):
        # type: () -> Iterator[T]
        '''Yield the elements in id order.'''
        for i in range(len(self)):
            yield self.id_to_elem[i]

    def __contains__(self, value):
        # type: (object) -> bool
        return value in self.elem_to_id

    def __len__(self):
        # type: () -> int
        return len(self.elem_to_id)

    def __getitem__(self, key):
        # type: (int) -> T
        '''Return the element with id `key`.

        Raises TypeError for slices; an unknown id raises the KeyError of the
        underlying dict lookup.
        '''
        if isinstance(key, slice):
            raise TypeError('Slices not supported.')
        return self.id_to_elem[key]

    def index(self, value):
        # type: (T) -> int
        '''Return the id of `value`, or the id of UNK if it is not present.

        Raises KeyError if `value` is missing and the vocabulary has no UNK
        entry either.
        '''
        try:
            return self.elem_to_id[value]
        except KeyError:
            return self.elem_to_id[UNK]

    def indices(self, values):
        # type: (Iterable[T]) -> List[int]
        '''Return `index(value)` for every value, as a list.'''
        return [self.index(value) for value in values]

    def __hash__(self):
        # type: () -> int
        # Identity-based hash. NOTE(review): `collections.abc.Set` supplies a
        # content-based `__eq__`, so two equal vocabularies hash differently;
        # left unchanged because callers may rely on it.
        return id(self)

    @classmethod
    def load(cls, in_path):
        '''Read a vocabulary from the JSON list stored at `in_path`.

        The file is expected to hold the complete element list (as written by
        `save`), so no special elements are prepended.
        '''
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(in_path, encoding='utf8') as f:
            elements = json.load(f)
        return cls(elements, special_elems=())

    def save(self, out_path):
        '''Write all elements, in id order, to `out_path` as a JSON list.'''
        with open(out_path, 'w', encoding='utf8') as f:
            json.dump(
                [self.id_to_elem[i] for i in range(len(self.id_to_elem))],
                f,
                ensure_ascii=False)
class VocabBuilder:
    '''Collects word frequencies and converts them into a `Vocab`.

    `min_freq` and `max_count` are optional limits applied by `finish`;
    `word_freq` is the `collections.Counter` holding the running counts.
    '''

    def __init__(self, min_freq=None, max_count=None):
        self.max_count = max_count
        self.min_freq = min_freq
        self.word_freq = collections.Counter()

    def add_word(self, word, count=1):
        '''Add `count` to the recorded frequency of `word`.'''
        self.word_freq[word] += count

    def finish(self, *args, **kwargs):
        '''Return a `Vocab` of the selected words.

        Keeps at most `max_count` of the most frequent words (all of them when
        `max_count` is None), drops those rarer than `min_freq` when it is
        set, and passes the survivors in sorted order to `Vocab`. Extra
        positional and keyword arguments are forwarded to `Vocab`.
        '''
        ranked = self.word_freq.most_common(self.max_count)
        threshold = self.min_freq
        if threshold is not None:
            # `ranked` is in descending frequency order, so everything from
            # the first too-rare entry onwards is cut off.
            cutoff = next(
                (pos for pos, (_, freq) in enumerate(ranked)
                 if freq < threshold),
                len(ranked))
            ranked = ranked[:cutoff]
        ordered = sorted(ranked)
        return Vocab((entry[0] for entry in ordered), *args, **kwargs)

    def save(self, path):
        '''Dump the frequency table to `path` as JSON.'''
        with open(path, "w", encoding='utf8') as out_file:
            json.dump(self.word_freq, out_file, ensure_ascii=False)

    def load(self, path):
        '''Replace the frequency table with the JSON content of `path`.'''
        with open(path, "r", encoding='utf8') as in_file:
            loaded = json.load(in_file)
        self.word_freq = collections.Counter(loaded)
|