ssa-perin / data /field /nested_field.py
larkkin's picture
Add application code and models, update README
8044721
#!/usr/bin/env python3
# coding=utf-8
import torch
from data.field.mini_torchtext.field import NestedField as TorchTextNestedField
class NestedField(TorchTextNestedField):
    """Nested field that delegates padding and numericalization to ``nesting_field``.

    ``pad`` and ``numericalize`` each hand the whole input to the nesting
    field in a single call, and ``build_vocab`` makes this field and the
    nesting field share one vocabulary object.
    """

    def pad(self, example):
        """Pad ``example`` using the nesting field.

        The nesting field's ``include_lengths`` flag is first set to this
        field's own flag, so both fields agree for the duration of the call.

        Args:
            example: The input to pad; only ``len(example)`` and the nested
                ``pad`` call are applied to it here.

        Returns:
            The nesting field's ``pad`` result when ``include_lengths`` is
            falsy. Otherwise a tuple ``(padded, sentence_length,
            word_lengths)`` where ``sentence_length`` is ``len(example)``
            taken before padding and the other two values are the pair
            returned by the nesting field's ``pad``.
        """
        nested = self.nesting_field
        nested.include_lengths = self.include_lengths
        if not self.include_lengths:
            return nested.pad(example)

        sentence_length = len(example)
        # With include_lengths enabled the nested pad is expected to return
        # a (padded, lengths) pair.
        padded, word_lengths = nested.pad(example)
        return padded, sentence_length, word_lengths

    def numericalize(self, arr, device=None):
        """Numericalize ``arr`` using the nesting field.

        Args:
            arr: The output of ``pad``: the padded data alone, or the tuple
                ``(padded, sentence_length, word_lengths)`` when
                ``include_lengths`` is truthy.
            device: Forwarded to the nesting field's ``numericalize`` and to
                ``torch.tensor`` for the length tensors.

        Returns:
            The nesting field's numericalized result, or a tuple
            ``(numericalized, sentence_length, word_lengths)`` with both
            lengths converted to tensors of ``self.dtype`` when
            ``include_lengths`` is truthy.
        """
        if self.include_lengths:
            arr, sentence_length, word_lengths = arr

        nested = self.nesting_field
        # The lengths were already split off above, so the nested call must
        # receive (and return) the bare data only.
        nested.include_lengths = False
        try:
            numericalized = nested.numericalize(arr, device=device)
        finally:
            # Re-enable the flag even if the nested call raises, so the
            # shared nesting field is never left with the flag switched off.
            nested.include_lengths = True

        if not self.include_lengths:
            return numericalized

        sentence_length = torch.tensor(sentence_length, dtype=self.dtype, device=device)
        word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
        return (numericalized, sentence_length, word_lengths)

    def build_vocab(self, *args, **kwargs):
        """Build the vocabulary through the nesting field and share it.

        Args:
            *args: ``torch.utils.data.Dataset`` objects and/or other
                iterables. For a dataset, the examples of every entry in
                ``arg.fields`` that is this very field object are collected
                via ``arg.get_examples(name)``; anything else is used as a
                source directly.
            **kwargs: Forwarded unchanged to the nesting field's
                ``build_vocab``.
        """
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
            else:
                sources.append(arg)

        # Remove one level of nesting: every element of every source is
        # passed to the nesting field as its own positional argument.
        flattened = [item for source in sources for item in source]

        # Intended to only build the vocabulary, without loading vectors.
        # NOTE(review): kwargs are forwarded as-is -- confirm callers do not
        # pass vector-loading options here.
        self.nesting_field.build_vocab(*flattened, **kwargs)

        # Deliberately start the MRO lookup *after* TorchTextNestedField, so
        # that class's own build_vocab is skipped. Called without sources;
        # presumably this creates the initial ``self.vocab`` -- TODO confirm.
        super(TorchTextNestedField, self).build_vocab()

        # Merge the nested vocabulary into this one, copy its frequency
        # table, then make the nesting field use the merged vocabulary.
        self.vocab.extend(self.nesting_field.vocab)
        self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
        self.nesting_field.vocab = self.vocab