Spaces:
Runtime error
Runtime error
import torch | |
from data.field.mini_torchtext.field import RawField | |
from data.field.mini_torchtext.vocab import Vocab | |
from collections import Counter | |
class LabelField(RawField):
    """Field whose examples are iterables of label tokens mapped to integer ids.

    ``build_vocab`` counts every label token seen in the given data and stores
    a ``Vocab`` (built with no special tokens) in ``self.vocab``.
    ``numericalize`` / ``process`` then turn one example into a 1-D
    ``torch.long`` tensor of ids plus a 0-dim ``torch.long`` length tensor.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the field; all arguments go unchanged to ``RawField``.

        The original method was misspelled ``__self__`` and therefore never
        ran, so ``self.vocab`` was not initialised. Forwarding ``*args`` and
        ``**kwargs`` (e.g. ``preprocessing=...``) keeps every existing call
        pattern working.
        """
        super().__init__(*args, **kwargs)
        # Populated by build_vocab(); numericalize() requires it.
        self.vocab = None

    def build_vocab(self, *args, **kwargs):
        """Count label tokens and store the resulting ``Vocab`` in ``self.vocab``.

        Args:
            *args: Each item is either a ``torch.utils.data.Dataset`` or a
                plain iterable of examples. For a dataset, the examples of
                every column in ``dataset.fields`` bound to *this* field object
                are fetched with ``dataset.get_examples(name)``.
            **kwargs: Accepted for interface compatibility; currently ignored.
        """
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources.extend(
                    arg.get_examples(name)
                    for name, field in arg.fields.items()
                    if field is self
                )
            else:
                sources.append(arg)

        counter = Counter()
        for examples in sources:
            for example in examples:
                # Each example is an iterable of label tokens; count each token.
                counter.update(example)
        self.vocab = Vocab(counter, specials=[])

    def process(self, example, device=None):
        """Numericalize ``example``; returns ``(tensor, length)``."""
        tensor, lengths = self.numericalize(example, device=device)
        return tensor, lengths

    def numericalize(self, example, device=None):
        """Convert one example into id and length tensors on ``device``.

        Args:
            example: Iterable of label tokens; each must be a key of
                ``self.vocab.stoi`` (behaviour for unseen labels depends on
                ``Vocab`` — NOTE(review): with ``specials=[]`` there is likely
                no unk fallback; confirm).
            device: Target device for both returned tensors (``None`` -> default).

        Returns:
            ``(tensor, length)``: a 1-D ``torch.long`` tensor of ids and a
            0-dim ``torch.long`` tensor holding the number of labels.
        """
        # Ids are shifted by one so no real label maps to 0
        # (presumably reserved for padding downstream — TODO confirm).
        ids = [self.vocab.stoi[label] + 1 for label in example]
        # torch.tensor honours any device; the legacy torch.LongTensor
        # constructor rejects non-CPU devices.
        tensor = torch.tensor(ids, dtype=torch.long, device=device)
        length = torch.tensor(len(ids), dtype=torch.long, device=device)
        return tensor, length