Spaces:
Runtime error
Runtime error
File size: 1,213 Bytes
8044721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
from data.field.mini_torchtext.field import RawField
from data.field.mini_torchtext.vocab import Vocab
from collections import Counter
class LabelField(RawField):
    """Field whose examples are sequences of labels mapped to integer indices.

    ``build_vocab`` counts every label seen in the given sources and builds a
    ``Vocab`` from those counts; ``numericalize`` turns one example into a
    1-D index tensor (``vocab.stoi[label] + 1`` per label) plus its length.
    """

    def __init__(self, preprocessing=None, **kwargs):
        """Create the field.

        Args:
            preprocessing: Forwarded to ``RawField.__init__``.
            **kwargs: Any further keyword arguments, forwarded unchanged to
                ``RawField.__init__``.
        """
        super().__init__(preprocessing=preprocessing, **kwargs)
        # Populated by build_vocab(); None until then.
        self.vocab = None

    def build_vocab(self, *args, **kwargs):
        """Build ``self.vocab`` from the labels found in ``args``.

        Args:
            *args: Label sources. A ``torch.utils.data.Dataset`` contributes
                ``arg.get_examples(name)`` for every ``name`` in ``arg.fields``
                whose field is this object; any other argument is used
                directly as an iterable of examples.
            **kwargs: Accepted for interface compatibility; currently ignored.
        """
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources.extend(
                    arg.get_examples(name)
                    for name, field in arg.fields.items()
                    if field is self
                )
            else:
                sources.append(arg)

        counter = Counter()
        for examples in sources:
            for example in examples:
                # Counter.update iterates the example, so each element is
                # counted as one label (presumably examples are lists of
                # labels; a bare string would be counted per character —
                # TODO confirm).
                counter.update(example)
        # Ask Vocab for no special tokens.
        self.vocab = Vocab(counter, specials=[])

    def process(self, example, device=None):
        """Numericalize ``example``.

        Returns:
            The ``(tensor, length)`` pair produced by ``numericalize``.
        """
        tensor, length = self.numericalize(example, device=device)
        return tensor, length

    def numericalize(self, example, device=None):
        """Convert a sequence of labels into index and length tensors.

        Args:
            example: Iterable of labels; each is looked up in
                ``self.vocab.stoi``, so ``build_vocab`` must have run first.
            device: Device on which to create both tensors (``None`` uses
                torch's default).

        Returns:
            ``(tensor, length)``: a 1-D ``torch.long`` tensor holding
            ``vocab.stoi[label] + 1`` for each label, and a 0-d ``torch.long``
            tensor holding the number of labels.
        """
        # The +1 shifts every index up so that 0 is never produced —
        # presumably reserved for padding (TODO confirm against the collator).
        indices = [self.vocab.stoi[label] + 1 for label in example]
        # torch.tensor honours ``device``; the legacy torch.LongTensor(...)
        # constructor only accepts CPU devices.
        tensor = torch.tensor(indices, dtype=torch.long, device=device)
        length = torch.tensor(len(indices), dtype=torch.long, device=device)
        return tensor, length
|