import torch
from data.field.mini_torchtext.field import RawField
from data.field.mini_torchtext.vocab import Vocab
from collections import Counter


class LabelField(RawField):
    """Field that maps each example's list of labels to a tensor of indices.

    The vocabulary is built from the data with ``build_vocab`` and holds no
    special tokens (``specials=[]``).
    """

    def __init__(self, preprocessing=None, **kwargs):
        # Previously misspelled ``__self__``, so this constructor never ran and
        # ``self.vocab`` was left unset. Extra keyword arguments are forwarded so
        # callers that relied on the inherited ``RawField.__init__`` still work.
        super().__init__(preprocessing=preprocessing, **kwargs)
        self.vocab = None

    def build_vocab(self, *args, **kwargs):
        """Count labels from the given sources and build ``self.vocab``.

        Args:
            *args: Either ``torch.utils.data.Dataset`` objects or iterables of
                examples. For a dataset, examples are taken via
                ``arg.get_examples(name)`` for every field name bound to this
                field instance.
            **kwargs: Accepted but currently ignored.
        """
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                # Collect examples for every column of the dataset that uses
                # this exact field object (identity check, not equality).
                sources += [
                    arg.get_examples(name)
                    for name, field in arg.fields.items()
                    if field is self
                ]
            else:
                sources.append(arg)

        counter = Counter()
        for data in sources:
            for x in data:
                # Each example is iterated element-wise, i.e. it is treated as
                # a sequence of labels (matches numericalize below).
                counter.update(x)
        self.vocab = Vocab(counter, specials=[])

    def process(self, example, device=None):
        """Numericalize ``example``; returns ``(tensor, length)``."""
        return self.numericalize(example, device=device)

    def numericalize(self, example, device=None):
        """Convert a sequence of labels into index tensors.

        Args:
            example: Iterable of labels, each looked up in ``self.vocab.stoi``.
            device: Optional torch device for the returned tensors.

        Returns:
            A tuple ``(tensor, length)``: a 1-D long tensor of label indices,
            and a 0-dim long tensor holding the number of labels.
        """
        # ``+ 1`` shifts every vocab index up by one; presumably this reserves
        # index 0 (e.g. for padding) — confirm against the consumer.
        indices = [self.vocab.stoi[x] + 1 for x in example]
        # torch.tensor (not the legacy torch.LongTensor constructor) is needed
        # so that non-CPU devices are accepted.
        tensor = torch.tensor(indices, dtype=torch.long, device=device)
        length = torch.tensor(len(indices), dtype=torch.long, device=device)
        return tensor, length