#!/usr/bin/env python3 # coding=utf-8 import torch from data.field.mini_torchtext.field import RawField from data.field.mini_torchtext.vocab import Vocab from collections import Counter import types class EdgeLabelField(RawField): def process(self, edges, device=None): edges, masks = self.numericalize(edges) edges, masks = self.pad(edges, masks, device) return edges, masks def pad(self, edges, masks, device): n_labels = len(self.vocab) tensor = torch.zeros(edges[0], edges[1], n_labels, dtype=torch.long, device=device) mask_tensor = torch.zeros(edges[0], edges[1], dtype=torch.bool, device=device) for edge in edges[-1]: tensor[edge[0], edge[1], edge[2]] = 1 for mask in masks[-1]: mask_tensor[mask[0], mask[1]] = mask[2] return tensor, mask_tensor def numericalize(self, arr): def multi_map(array, function): if isinstance(array, tuple): return (array[0], array[1], function(array[2])) elif isinstance(array, list): return [multi_map(array[i], function) for i in range(len(array))] else: return array mask = multi_map(arr, lambda x: x is None) arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x in self.vocab.stoi else 0) return arr, mask def build_vocab(self, *args): def generate(l): if isinstance(l, tuple): yield l[2] elif isinstance(l, list) or isinstance(l, types.GeneratorType): for i in l: yield from generate(i) else: return counter = Counter() sources = [] for arg in args: if isinstance(arg, torch.utils.data.Dataset): sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self] else: sources.append(arg) for x in generate(sources): if x is not None: counter.update([x]) self.vocab = Vocab(counter, specials=[])