Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# coding=utf-8 | |
import torch | |
from data.parser.json_parser import example_from_json | |
class AbstractParser(torch.utils.data.Dataset):
    """Map-style torch dataset whose examples are built by ``example_from_json``.

    Indexing returns a dict mapping each registered field name to the output of
    that field's ``process`` method applied to the example's attribute.
    """

    def __init__(self, fields, data, filter_pred=None):
        """Build examples from ``data`` and register the fields that process them.

        Args:
            fields: either a list of ``(name, field)`` pairs, or a dict whose
                values are ``(name, field)`` pairs or lists of such pairs. It is
                passed unchanged to ``example_from_json``. A ``name`` may itself
                be a tuple; the matching ``field`` is then zipped against it and
                each ``(sub_name, sub_field)`` is registered separately.
            data: mapping whose values are converted into examples, in sorted
                key order. Presumably decoded JSON records — confirm against
                ``example_from_json``.
            filter_pred: optional predicate; examples for which it returns a
                falsy value are dropped.
        """
        super().__init__()
        # Sort by key so example order depends only on the keys of ``data``.
        self.examples = [
            example_from_json(record, fields)
            for _, record in sorted(data.items())
        ]
        if filter_pred is not None:
            # ``self.examples`` is always a list at this point, so filter eagerly.
            self.examples = [ex for ex in self.examples if filter_pred(ex)]
        self.fields = self._normalize_fields(fields)

    @staticmethod
    def _normalize_fields(fields):
        """Return ``fields`` as a flat ``{name: field}`` dict with tuple names unpacked."""
        if isinstance(fields, dict):
            pairs = []
            for value in fields.values():
                # Each dict value is one (name, field) pair or a list of them.
                if isinstance(value, list):
                    pairs.extend(value)
                else:
                    pairs.append(value)
            fields = pairs
        flat = dict(fields)
        # Iterate over a snapshot because the dict is modified inside the loop.
        for name, field in list(flat.items()):
            if isinstance(name, tuple):
                flat.update(zip(name, field))
                del flat[name]
        return flat

    def __getitem__(self, i):
        """Return ``{name: field.process(value, device=None)}`` for the ``i``-th example.

        Fields registered as ``None`` are skipped.
        """
        example = self.examples[i]
        return {
            name: field.process(getattr(example, name), device=None)
            for name, field in self.fields.items()
            if field is not None
        }

    def __len__(self):
        """Return the number of (post-filter) examples."""
        return len(self.examples)

    def get_examples(self, attr):
        """Yield attribute ``attr`` of every example, in order.

        Yields nothing when ``attr`` is not a registered field name.
        """
        if attr not in self.fields:
            return
        for example in self.examples:
            yield getattr(example, attr)