# Spaces:
# Runtime error
# Runtime error
from __future__ import unicode_literals

from collections import defaultdict
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Vocab(object):
    """Defines a vocabulary object that will be used to numericalize a field.

    Attributes:
        freqs: A collections.Counter object holding the frequencies of tokens
            in the data used to build the Vocab.
        stoi: A collections.defaultdict instance mapping token strings to
            numerical identifiers. When ``UNK`` is one of the specials, its
            default factory returns ``unk_index``.
        itos: A list of token strings indexed by their numerical identifiers.
        unk_index: Position of ``UNK`` in ``itos``, or None when ``UNK`` is
            not among the special tokens.
    """

    # TODO (@mttk): Populate class with default values of special symbols
    UNK = '<unk>'

    # NOTE: the list default for ``specials`` is never mutated; it is copied
    # with list() before use.
    def __init__(self, counter, max_size=None, min_freq=1,
                 specials=['<unk>', '<pad>'], specials_first=True):
        """Create a Vocab object from a collections.Counter.

        Arguments:
            counter: collections.Counter object holding the frequencies of
                each value found in the data. It is stored as ``freqs``
                unchanged; all filtering is done on a copy. Any mapping with
                ``copy``, ``pop`` and ``items`` is accepted.
            max_size: The maximum size of the vocabulary, or None for no
                maximum. Special tokens do not count towards this limit.
                Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
            specials: The special tokens (e.g., padding or eos) that
                will be added to the vocabulary. Default: ['<unk>', '<pad>']
            specials_first: Whether to add special tokens into the vocabulary at first.
                If it is False, they are added into the vocabulary at last.
                Default: True.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)
        # Normalize once so any iterable of tokens works below.
        specials = list(specials)

        self.itos = list()
        self.unk_index = None
        if specials_first:
            self.itos = list(specials)
            # only extend max size if specials are prepended
            max_size = None if max_size is None else max_size + len(specials)

        # frequencies of special tokens are not counted when building vocabulary
        # in frequency order; pop() tolerates specials missing from the mapping
        for tok in specials:
            counter.pop(tok, None)

        # sort by descending frequency, then alphabetically
        words_and_frequencies = sorted(
            counter.items(), key=lambda tup: (-tup[1], tup[0]))

        for word, freq in words_and_frequencies:
            # Frequencies are descending, so the first token below min_freq
            # ends the scan, as does reaching the size limit.
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        if Vocab.UNK in specials:  # hard-coded for now
            unk_index = specials.index(Vocab.UNK)  # position in list
            # account for ordering of specials: when they are appended, they
            # start right after the regular tokens collected so far
            self.unk_index = unk_index if specials_first else len(self.itos) + unk_index
            self.stoi = defaultdict(self._default_unk_index)
        else:
            # No default factory: missing keys raise KeyError.
            self.stoi = defaultdict()

        if not specials_first:
            self.itos.extend(specials)

        # stoi is simply a reverse dict for itos
        self.stoi.update({tok: i for i, tok in enumerate(self.itos)})

    def _default_unk_index(self):
        """Default factory for ``stoi``: return the index of ``UNK``."""
        return self.unk_index

    def __getitem__(self, token):
        """Return the index of ``token``, or that of ``UNK`` if it is unknown.

        Returns None when the token is unknown and ``UNK`` is not in the
        vocabulary. The lookup goes through ``dict.get`` so an unknown token
        is not inserted into ``stoi``, which indexing ``stoi`` directly would
        do through the defaultdict factory.
        """
        return self.stoi.get(token, self.stoi.get(Vocab.UNK))

    def __getstate__(self):
        """Return the instance state with ``stoi`` cast to a plain dict."""
        # avoid pickling defaultdict
        attrs = dict(self.__dict__)
        # cast to regular dict
        attrs['stoi'] = dict(self.stoi)
        return attrs

    def __setstate__(self, state):
        """Restore the instance state, rebuilding ``stoi`` as a defaultdict.

        Note that ``state['stoi']`` is replaced in place in the given dict.
        """
        if state.get("unk_index", None) is None:
            stoi = defaultdict()
        else:
            stoi = defaultdict(self._default_unk_index)
        stoi.update(state['stoi'])
        state['stoi'] = stoi
        self.__dict__.update(state)

    def __eq__(self, other):
        """Vocabs are equal when ``freqs``, ``stoi`` and ``itos`` all match."""
        if not isinstance(other, Vocab):
            # Let Python try the reflected comparison instead of failing
            # with AttributeError on a foreign object.
            return NotImplemented
        return (self.freqs == other.freqs
                and self.stoi == other.stoi
                and self.itos == other.itos)

    def __len__(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.itos)

    def extend(self, v, sort=False):
        """Append the tokens of ``v`` that this vocabulary does not have yet.

        Arguments:
            v: Object exposing an ``itos`` list of tokens (e.g. another Vocab).
            sort: If True, tokens of ``v`` are visited in sorted order rather
                than in their ``itos`` order. Default: False.
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1