from __future__ import unicode_literals
from collections import defaultdict
import logging

logger = logging.getLogger(__name__)


class Vocab(object):
    """Defines a vocabulary object that will be used to numericalize a field.

    Attributes:
        freqs: A collections.Counter object holding the frequencies of tokens
            in the data used to build the Vocab.
        stoi: A collections.defaultdict instance mapping token strings to
            numerical identifiers.
        itos: A list of token strings indexed by their numerical identifiers.
    """

    # TODO (@mttk): Populate class with default values of special symbols
    UNK = '<unk>'

    def __init__(self, counter, max_size=None, min_freq=1,
                 specials=['<unk>', '<pad>'], specials_first=True):
        """Create a Vocab object from a collections.Counter.

        Arguments:
            counter: collections.Counter object holding the frequencies of
                each value found in the data.
            max_size: The maximum size of the vocabulary, or None for no
                maximum. Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
            specials: The list of special tokens (e.g., padding or eos) that
                will be prepended to the vocabulary.
                Default: ['<unk>', '<pad>']
            specials_first: Whether to add the special tokens at the start of
                the vocabulary. If False, they are appended at the end.
                Default: True.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list()
        self.unk_index = None
        if specials_first:
            self.itos = list(specials)
            # only extend max size if specials are prepended
            max_size = None if max_size is None else max_size + len(specials)

        # frequencies of special tokens are not counted when building the
        # vocabulary in frequency order
        for tok in specials:
            del counter[tok]

        # sort alphabetically first, then (stably) by descending frequency,
        # so equally frequent tokens end up in alphabetical order
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        if Vocab.UNK in specials:  # hard-coded for now
            unk_index = specials.index(Vocab.UNK)  # position in list
            # account for ordering of specials, set variable
            self.unk_index = unk_index if specials_first else len(self.itos) + unk_index
            self.stoi = defaultdict(self._default_unk_index)
        else:
            self.stoi = defaultdict()

        if not specials_first:
            self.itos.extend(list(specials))

        # stoi is simply a reverse dict for itos
        self.stoi.update({tok: i for i, tok in enumerate(self.itos)})

    def _default_unk_index(self):
        return self.unk_index

    def __getitem__(self, token):
        return self.stoi.get(token, self.stoi.get(Vocab.UNK))

    def __getstate__(self):
        # avoid pickling the defaultdict
        attrs = dict(self.__dict__)
        # cast to a regular dict
        attrs['stoi'] = dict(self.stoi)
        return attrs

    def __setstate__(self, state):
        if state.get("unk_index", None) is None:
            stoi = defaultdict()
        else:
            stoi = defaultdict(self._default_unk_index)
        stoi.update(state['stoi'])
        state['stoi'] = stoi
        self.__dict__.update(state)

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def extend(self, v, sort=False):
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
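

# Usage sketch (illustrative only, not part of the class): builds a small
# Vocab from a collections.Counter and looks tokens up by string and index.
# The __main__ guard and the token counts below are assumptions for the
# example; they do not come from the original module.
if __name__ == '__main__':
    from collections import Counter

    counts = Counter({'the': 5, 'cat': 3, 'sat': 3, 'mat': 1})
    vocab = Vocab(counts, max_size=None, min_freq=2)

    # '<unk>' and '<pad>' occupy indices 0 and 1 because specials_first=True.
    assert vocab.itos[:2] == ['<unk>', '<pad>']
    # 'mat' falls below min_freq, so lookup falls back to the unknown index.
    assert vocab['mat'] == vocab.stoi['<unk>']
    # In-vocabulary tokens round-trip between stoi and itos.
    assert vocab.itos[vocab['cat']] == 'cat'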