Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# coding=utf-8
import torch | |
from data.field.mini_torchtext.field import RawField | |
from data.field.mini_torchtext.vocab import Vocab | |
from collections import Counter | |
import types | |
class EdgeField(RawField):
    """Field that turns a sparse edge description into a dense ``torch.long`` matrix.

    The edge description handled by this class is a list whose first two
    items are the matrix dimensions and whose last item is an iterable of
    ``(row, column, label)`` entries. The outer container must be a list:
    ``numericalize`` and ``build_vocab`` treat every tuple they meet as a
    single edge entry.
    """

    def __init__(self):
        super().__init__()
        # No vocabulary until build_vocab() is called; while it is None,
        # numericalize() passes labels through unchanged.
        self.vocab = None

    def process(self, edges, device=None):
        """Numericalize ``edges`` and return the resulting dense tensor.

        Args:
            edges: edge description (see the class docstring).
            device: forwarded to ``torch.zeros``; ``None`` uses torch's default.

        Returns:
            The ``torch.long`` tensor produced by :meth:`pad`.
        """
        return self.pad(self.numericalize(edges), device)

    def pad(self, edges, device):
        """Scatter the edge entries into a zero-initialised matrix.

        ``edges[0]`` and ``edges[1]`` give the tensor size and ``edges[-1]``
        holds the entries; for each entry, cell ``[entry[0], entry[1]]`` is
        set to ``entry[2]``. Cells without an entry stay 0.
        """
        n_rows, n_cols, entries = edges[0], edges[1], edges[-1]
        matrix = torch.zeros(n_rows, n_cols, dtype=torch.long, device=device)
        for entry in entries:
            row, col, value = entry[0], entry[1], entry[2]
            matrix[row, col] = value
        return matrix

    def numericalize(self, arr):
        """Replace every edge label in ``arr`` with its vocabulary index.

        Lists are walked recursively; each tuple is rebuilt as a 3-tuple with
        its third item looked up in ``self.vocab.stoi`` (a ``None`` label
        becomes 0). Anything else is returned untouched. When no vocabulary
        has been built, ``arr`` is returned as-is.
        """
        if self.vocab is None:
            return arr

        def to_index(label):
            if label is None:
                return 0
            return self.vocab.stoi[label]

        def convert(node):
            if isinstance(node, tuple):
                return (node[0], node[1], to_index(node[2]))
            if isinstance(node, list):
                return [convert(child) for child in node]
            return node

        return convert(arr)

    def build_vocab(self, *args):
        """Build ``self.vocab`` from the labels found in ``args``.

        Each argument is either a ``torch.utils.data.Dataset`` or a raw nested
        structure. For a dataset, the examples of every entry in its
        ``fields`` mapping that is this very object are collected via
        ``get_examples(name)``. A raw structure is a list or generator whose
        tuples carry the label at index 2. ``None`` labels are not counted.
        """

        def iter_labels(node):
            # Tuples are leaf entries; lists and generators are containers.
            # Every other object contributes nothing.
            if isinstance(node, tuple):
                yield node[2]
            elif isinstance(node, (list, types.GeneratorType)):
                for child in node:
                    yield from iter_labels(child)

        sources = []
        for arg in args:
            if not isinstance(arg, torch.utils.data.Dataset):
                sources.append(arg)
                continue
            sources.extend(
                arg.get_examples(name)
                for name, field in arg.fields.items()
                if field is self
            )

        counter = Counter(
            label for label in iter_labels(sources) if label is not None
        )
        self.vocab = Vocab(counter, specials=[])