|
import collections |
|
import collections.abc |
|
import json |
|
import operator |
|
|
|
|
|
class Sentinel(object):

    '''Used to represent special values like UNK.

    Two Sentinels order by their names; a Sentinel sorts before any
    non-Sentinel value, so special tokens come first in a mixed sort.
    '''

    __slots__ = ('name', )

    def __init__(self, name):
        # name: human-readable identifier, e.g. 'UNK'.
        self.name = name

    def __repr__(self):
        return '<' + self.name + '>'

    def __lt__(self, other):
        # BUG FIX: the original tested isinstance(other, IndexedSet.Sentinel),
        # but no IndexedSet exists in this module, so every comparison raised
        # NameError. The intended check is against this class itself.
        if isinstance(other, Sentinel):
            return self.name < other.name
        # Sentinels sort before everything else (e.g. plain strings).
        return True
|
|
|
|
|
# Special vocabulary tokens: unknown word, beginning-of-sequence,
# end-of-sequence. Plain strings (not Sentinel instances) so they can be
# serialized with json alongside ordinary words.
UNK = '<UNK>'

BOS = '<BOS>'

EOS = '<EOS>'
|
|
|
|
|
class Vocab(collections.abc.Set):

    '''Immutable set of elements with a contiguous integer id per element.

    Special elements (UNK/BOS/EOS by default) receive the lowest ids, in
    the order given, followed by the elements of *iterable* in order.
    '''

    def __init__(self, iterable, special_elems=(UNK, BOS, EOS)):
        elements = list(special_elems)
        elements.extend(iterable)
        # Duplicates would make id_to_elem / elem_to_id inconsistent.
        # (Was a bare `assert`, which is stripped under `python -O`.)
        if len(elements) != len(set(elements)):
            raise ValueError('Vocab elements must be unique.')

        self.id_to_elem = {i: elem for i, elem in enumerate(elements)}
        self.elem_to_id = {elem: i for i, elem in enumerate(elements)}

    def __iter__(self):
        '''Yield elements in id order (0, 1, 2, ...).'''
        for i in range(len(self)):
            yield self.id_to_elem[i]

    def __contains__(self, value):
        return value in self.elem_to_id

    def __len__(self):
        return len(self.elem_to_id)

    def __getitem__(self, key):
        '''Return the element whose id is *key*; slicing is not supported.'''
        if isinstance(key, slice):
            raise TypeError('Slices not supported.')
        return self.id_to_elem[key]

    def index(self, value):
        '''Return the id of *value*, or the UNK id if *value* is unknown.'''
        try:
            return self.elem_to_id[value]
        except KeyError:
            return self.elem_to_id[UNK]

    def indices(self, values):
        '''Map each element of *values* to an id (UNK id for unknowns).'''
        return [self.index(value) for value in values]

    def __hash__(self):
        # Identity hash: two Vocabs with equal contents are distinct dict keys.
        return id(self)

    @classmethod
    def load(cls, in_path):
        '''Load a Vocab previously written by save().

        The saved list already contains any special elements in id order,
        so none are prepended here (special_elems=()).
        '''
        # BUG FIX: the classmethod's first parameter was named `self` and the
        # file handle was never closed; use `cls` and a context manager.
        with open(in_path, encoding='utf8') as f:
            return cls(json.load(f), special_elems=())

    def save(self, out_path):
        '''Write the elements as a JSON list ordered by id.'''
        with open(out_path, 'w', encoding='utf8') as f:
            json.dump([self.id_to_elem[i] for i in range(len(self.id_to_elem))],
                      f, ensure_ascii=False)
|
|
|
|
|
class VocabBuilder:

    '''Accumulates word frequencies and turns them into a Vocab.'''

    def __init__(self, min_freq=None, max_count=None):
        # min_freq: drop words seen fewer than this many times (None = keep all).
        # max_count: keep at most this many of the most frequent words (None = no cap).
        self.word_freq = collections.Counter()
        self.min_freq = min_freq
        self.max_count = max_count

    def add_word(self, word, count=1):
        '''Record *count* additional occurrences of *word*.'''
        self.word_freq[word] += count

    def finish(self, *args, **kwargs):
        '''Build and return a Vocab; extra arguments go to the Vocab constructor.'''
        candidates = self.word_freq.most_common(self.max_count)
        if self.min_freq is not None:
            # most_common() is sorted by descending frequency, so once a word
            # falls below the threshold, every later word does too.
            kept = []
            for word, freq in candidates:
                if freq < self.min_freq:
                    break
                kept.append((word, freq))
            candidates = kept

        # Sorting the (word, freq) pairs orders the surviving words
        # lexicographically before they are handed to Vocab.
        return Vocab(
            (word for word, _freq in sorted(candidates)),
            *args, **kwargs)

    def save(self, path):
        '''Persist the raw frequency table as JSON.'''
        with open(path, "w", encoding='utf8') as f:
            json.dump(self.word_freq, f, ensure_ascii=False)

    def load(self, path):
        '''Replace the frequency table with one previously written by save().'''
        with open(path, "r", encoding='utf8') as f:
            self.word_freq = collections.Counter(json.load(f))
|
|