Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
# Copyright 2017-present, Facebook, Inc. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""Data processing/loading helpers.""" | |
import numpy as np | |
import logging | |
import unicodedata | |
from torch.utils.data import Dataset | |
from torch.utils.data.sampler import Sampler | |
from .vector import vectorize | |
logger = logging.getLogger(__name__) | |
# ------------------------------------------------------------------------------ | |
# Dictionary class for tokens. | |
# ------------------------------------------------------------------------------ | |
class Dictionary(object):
    """Bidirectional vocabulary mapping tokens to integer indices and back.

    Index 0 is reserved for ``NULL`` and index 1 for ``UNK``.  String tokens
    are stored and looked up in Unicode NFD form, so visually identical
    tokens with different code-point compositions map to the same index.
    """

    NULL = '<NULL>'
    UNK = '<UNK>'
    # First index handed out to a regular token (0 and 1 are the specials).
    START = 2

    @staticmethod
    def normalize(token):
        """Return the Unicode NFD normal form of ``token``."""
        # Must be a staticmethod: it is called as ``self.normalize(key)``.
        return unicodedata.normalize('NFD', token)

    def __init__(self):
        self.tok2ind = {self.NULL: 0, self.UNK: 1}
        self.ind2tok = {0: self.NULL, 1: self.UNK}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        """Iterate over the (normalized) tokens, specials included."""
        return iter(self.tok2ind)

    def __contains__(self, key):
        """Return True if ``key`` is a known index (int) or token (str).

        Keys of any other type are never contained.
        """
        if isinstance(key, int):
            return key in self.ind2tok
        if isinstance(key, str):
            return self.normalize(key) in self.tok2ind
        return False

    def __getitem__(self, key):
        """Look up a token by index, or an index by token.

        Args:
            key: an ``int`` index or a ``str`` token.

        Returns:
            For an int, the token at that index, or ``UNK`` if unknown.
            For a str, the token's index, or the ``UNK`` index if unknown.
            ``None`` for keys of any other type (kept for compatibility).
        """
        if isinstance(key, int):
            return self.ind2tok.get(key, self.UNK)
        if isinstance(key, str):
            return self.tok2ind.get(self.normalize(key),
                                    self.tok2ind.get(self.UNK))
        return None

    def __setitem__(self, key, item):
        """Set one direction of the mapping: ``d[index] = token`` or
        ``d[token] = index``.

        Raises:
            RuntimeError: if ``(key, item)`` is not ``(int, str)`` or
                ``(str, int)``.
        """
        if isinstance(key, int) and isinstance(item, str):
            self.ind2tok[key] = item
        elif isinstance(key, str) and isinstance(item, int):
            # Normalize like add()/__getitem__ so the token stays reachable.
            self.tok2ind[self.normalize(key)] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        """Add ``token`` (NFD-normalized) with the next free index, if new."""
        token = self.normalize(token)
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def tokens(self):
        """Get dictionary tokens.

        Return all the words indexed by this dictionary, except for special
        tokens.
        """
        specials = {self.NULL, self.UNK}
        return [tok for tok in self.tok2ind if tok not in specials]
# ------------------------------------------------------------------------------ | |
# PyTorch dataset class for SQuAD (and SQuAD-like) data. | |
# ------------------------------------------------------------------------------ | |
class ReaderDataset(Dataset):
    """Torch ``Dataset`` over a list of reader examples.

    Each example is a dict with at least ``'document'`` and ``'question'``
    sequences.  Examples are converted lazily, one at a time, by
    ``vectorize`` when indexed.

    Args:
        examples: list of example dicts.
        model: forwarded unchanged to ``vectorize``.
        single_answer: forwarded unchanged to ``vectorize``.
    """

    def __init__(self, examples, model, single_answer=False):
        self.examples = examples
        self.model = model
        self.single_answer = single_answer

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]
        return vectorize(example, self.model, self.single_answer)

    def lengths(self):
        """Return ``(len(document), len(question))`` for every example."""
        sizes = []
        for example in self.examples:
            sizes.append((len(example['document']), len(example['question'])))
        return sizes
# ------------------------------------------------------------------------------ | |
# PyTorch sampler returning batches of indices sorted by length (doc, then question).
# ------------------------------------------------------------------------------ | |
class SortedBatchSampler(Sampler):
    """Sampler yielding indices grouped into batches of similar length.

    Examples are ordered by document length (longest first), then question
    length (longest first), with a fresh random key breaking remaining ties
    on every iteration.  The ordered indices are cut into contiguous chunks
    of ``batch_size``; when ``shuffle`` is true the order of those chunks
    (not their contents) is randomized.  The chunks are flattened back into
    a single index stream -- presumably consumed by a ``DataLoader`` with the
    same ``batch_size`` so the chunks line up with batches (verify at caller).

    Args:
        lengths: sequence of ``(document_length, question_length)`` pairs,
            one per example.
        batch_size: number of indices per chunk.
        shuffle: whether to shuffle chunk order on each iteration.
    """

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Negate lengths so an ascending argsort gives longest-first order;
        # the random column breaks ties between equal-length examples.
        # Use np.float64: the np.float_ alias was removed in NumPy 2.0.
        keys = np.array(
            [(-pair[0], -pair[1], np.random.random()) for pair in self.lengths],
            dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float64)],
        )
        indices = np.argsort(keys, order=('l1', 'l2', 'rand'))
        batches = [indices[i:i + self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            np.random.shuffle(batches)
        return iter([i for batch in batches for i in batch])

    def __len__(self):
        return len(self.lengths)