Spaces:
Runtime error
Runtime error
import torch | |
from data.field.mini_torchtext.field import Field as TorchTextField | |
from collections import Counter, OrderedDict | |
# Small change to vocabulary building so that it corresponds to our version of Dataset.
class Field(TorchTextField):
    """Field subclass overriding vocabulary building, processing and numericalization.

    ``build_vocab`` collects examples from datasets through
    ``dataset.get_examples(name)``; ``process`` and ``numericalize`` convert
    the value they are given into a ``torch.Tensor`` (optionally paired with
    its length).
    """

    def build_vocab(self, *args, **kwargs):
        """Count tokens from the given sources and store the result in ``self.vocab``.

        Args:
            *args: Data sources. A ``torch.utils.data.Dataset`` contributes
                ``arg.get_examples(name)`` for every ``(name, field)`` pair in
                ``arg.fields`` whose field is this very object; any other
                argument is used directly as an iterable of examples.
            **kwargs: Forwarded to ``self.vocab_cls``. An optional ``specials``
                entry (any iterable of tokens, or ``None``) is appended after
                the field's own special tokens.
        """
        counter = Counter()

        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
            else:
                sources.append(arg)

        for data in sources:
            for x in data:
                # A non-sequential example is a single token; wrap it so that
                # Counter.update counts it as one item instead of iterating it.
                if not self.sequential:
                    x = [x]
                counter.update(x)

        # Accept any iterable (or None) for the user-supplied specials rather
        # than requiring a list, which plain list concatenation would demand.
        extra_specials = kwargs.pop("specials", None)
        extra_specials = [] if extra_specials is None else list(extra_specials)

        own_specials = [self.unk_token, self.pad_token, self.init_token, self.eos_token]
        # OrderedDict.fromkeys removes duplicates while keeping first-seen order.
        specials = list(
            OrderedDict.fromkeys(tok for tok in own_specials + extra_specials if tok is not None)
        )
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)

    def process(self, example, device=None):
        """Numericalize ``example``, attaching its length when ``include_lengths`` is set.

        Args:
            example: The value to convert; ``len(example)`` is taken when
                ``self.include_lengths`` is true.
            device: Passed through to ``numericalize``.

        Returns:
            Whatever ``numericalize`` returns for the (possibly length-augmented) input.
        """
        if self.include_lengths:
            example = example, len(example)
        tensor = self.numericalize(example, device=device)
        return tensor

    def numericalize(self, ex, device=None):
        """Convert ``ex`` into a tensor of dtype ``self.dtype``.

        Args:
            ex: The data to convert, or a ``(data, lengths)`` tuple. A tuple is
                always unpacked this way, whether or not ``include_lengths`` is set.
            device: Device on which the tensors are created.

        Returns:
            The data tensor, or ``(tensor, lengths)`` when ``self.include_lengths`` is true.

        Raises:
            ValueError: If ``self.include_lengths`` is true and ``ex`` is not a tuple.
        """
        if self.include_lengths and not isinstance(ex, tuple):
            raise ValueError("Field has include_lengths set to True, but input data is not a tuple of (data batch, batch lengths).")

        if isinstance(ex, tuple):
            ex, lengths = ex
            # The lengths tensor is created with the field's dtype.
            lengths = torch.tensor(lengths, dtype=self.dtype, device=device)

        if self.use_vocab:
            # Map tokens to indices through the vocabulary's string-to-index table.
            if self.sequential:
                ex = [self.vocab.stoi[x] for x in ex]
            else:
                ex = self.vocab.stoi[ex]

            if self.postprocessing is not None:
                ex = self.postprocessing(ex, self.vocab)
        else:
            numericalization_func = self.dtypes[self.dtype]
            # Only non-sequential string values are coerced to a number here;
            # sequential data is handed to torch.tensor unchanged.
            if not self.sequential:
                ex = numericalization_func(ex) if isinstance(ex, str) else ex

            if self.postprocessing is not None:
                ex = self.postprocessing(ex, None)

        var = torch.tensor(ex, dtype=self.dtype, device=device)

        if self.sequential and not self.batch_first:
            var.t_()
        if self.sequential:
            var = var.contiguous()

        if self.include_lengths:
            return var, lengths
        return var