# Source: ssa-perin/data/field/label_field.py
# Repository page header: "larkkin's picture"
# Commit message: "Add supporting code from perin"
# Commit: 7daaa6b
import torch
from data.field.mini_torchtext.field import RawField
from data.field.mini_torchtext.vocab import Vocab
from collections import Counter
class LabelField(RawField):
    """Field that converts sequences of labels into index tensors.

    The vocabulary is built from label frequencies via ``build_vocab`` and is
    stored in ``self.vocab``. Numericalized indices are the vocabulary index
    of each label shifted by one (``stoi[label] + 1``).
    NOTE(review): the +1 shift presumably reserves index 0 (e.g. for padding
    or "no label") -- confirm against the consumers of these tensors.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the field and mark the vocabulary as not yet built.

        All arguments (e.g. ``preprocessing``) are forwarded unchanged to
        ``RawField.__init__``, so every call that worked before still works.

        The original code named this method ``__self__``, which Python never
        invokes as a constructor; as a result ``self.vocab`` was not
        initialized until ``build_vocab`` ran.
        """
        super().__init__(*args, **kwargs)
        self.vocab = None

    def build_vocab(self, *args, **kwargs):
        """Build ``self.vocab`` from one or more data sources.

        Args:
            *args: Each argument is either a ``torch.utils.data.Dataset``
                (its examples are collected for every entry of
                ``arg.fields`` whose field is this object) or an iterable of
                label sequences.
            **kwargs: Accepted for interface compatibility; not used.
        """
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                # Only pull the columns that are bound to this very field.
                sources.extend(
                    arg.get_examples(name)
                    for name, field in arg.fields.items()
                    if field is self
                )
            else:
                sources.append(arg)

        counter = Counter()
        for data in sources:
            for labels in data:
                # Each item is itself an iterable of labels for one example.
                counter.update(labels)

        self.vocab = Vocab(counter, specials=[])

    def process(self, example, device=None):
        """Numericalize ``example``; return a ``(tensor, length)`` pair.

        Args:
            example: Iterable of labels present in the vocabulary mapping.
            device: Optional device on which the tensors are created.
        """
        tensor, lengths = self.numericalize(example, device=device)
        return tensor, lengths

    def numericalize(self, example, device=None):
        """Map labels to shifted vocabulary indices and wrap them in tensors.

        Args:
            example: Iterable of labels looked up in ``self.vocab.stoi``.
            device: Optional device on which the tensors are created.

        Returns:
            A pair ``(tensor, length)``: a 1-D int64 tensor holding
            ``stoi[label] + 1`` for every label, and a 0-dim int64 tensor
            holding the number of labels.
        """
        indices = [self.vocab.stoi[label] + 1 for label in example]

        # torch.tensor (rather than the legacy torch.LongTensor constructor)
        # is used because the legacy constructor rejects non-CPU devices.
        length = torch.tensor(len(indices), dtype=torch.long, device=device)
        tensor = torch.tensor(indices, dtype=torch.long, device=device)
        return tensor, length