# Spaces:
# Runtime error
# Runtime error
# File size: 3,692 Bytes
# 8044721 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import six
import json
from functools import reduce
class Example(object):
    """Defines a single training or test example.

    Stores each column of the example as an attribute.

    Instances are built through the alternate constructors below. Each one
    takes the raw ``data`` plus a ``fields`` specification made of
    ``(name, field)`` pairs, calls ``field.preprocess`` on the matching raw
    value and stores the result on the instance under ``name``.
    """

    @classmethod
    def fromJSON(cls, data, fields):
        """Build an example from a JSON document.

        Arguments:
            data: JSON text, decoded with ``json.loads``.
            fields: Mapping from a key path to ``None`` (path is ignored), a
                single ``(name, field)`` pair, or a list of such pairs. A
                path such as ``'foo.bar'`` is split on ``'.'`` and followed
                one component at a time. When a list is reached, the
                component is looked up in every element of the list and the
                list of results is carried forward.

        Returns:
            A new instance with one attribute per ``(name, field)`` pair.

        Raises:
            ValueError: If a path component is missing, or if the value it
                has to be looked up in is not a JSON object.
        """
        not_found = "Specified key {} was not found in the input data"

        def descend(node, part):
            # Only JSON objects (dicts) can be indexed by a key. Testing
            # `part in node` on a string would be a substring test, and on a
            # number it would raise TypeError, so both are reported as a
            # missing key instead.
            if not isinstance(node, dict) or part not in node:
                raise ValueError(not_found.format(part))
            return node[part]

        def lookup(node, part):
            # A list fans the lookup out over its elements.
            if isinstance(node, list):
                return [descend(item, part) for item in node]
            return descend(node, part)

        ex = cls()
        obj = json.loads(data)
        for key, vals in fields.items():
            if vals is None:
                continue
            if not isinstance(vals, list):
                vals = [vals]
            # for processing the key likes 'foo.bar'
            path = key.split('.')
            for name, field in vals:
                # Resolved once per pair so every field receives its own
                # freshly built result list.
                value = reduce(lookup, path, obj)
                setattr(ex, name, field.preprocess(value))
        return ex

    @classmethod
    def fromdict(cls, data, fields):
        """Build an example from a mapping of raw values.

        Arguments:
            data: Mapping from key to raw value.
            fields: Mapping from key to ``None``, a single ``(name, field)``
                pair, or a list of such pairs.

        Returns:
            A new instance with one attribute per ``(name, field)`` pair.

        Raises:
            ValueError: If a key of ``fields`` is absent from ``data``. This
                is checked even when the key maps to ``None``.
        """
        ex = cls()
        for key, vals in fields.items():
            if key not in data:
                raise ValueError("Specified key {} was not found in "
                                 "the input data".format(key))
            if vals is None:
                continue
            if not isinstance(vals, list):
                vals = [vals]
            for name, field in vals:
                setattr(ex, name, field.preprocess(data[key]))
        return ex

    @classmethod
    def fromCSV(cls, data, fields, field_to_index=None):
        """Build an example from one row of column values.

        Arguments:
            data: Sequence of column values for the row.
            fields: Passed to ``fromlist`` when ``field_to_index`` is
                ``None``; otherwise it must be a dict and is passed to
                ``fromdict``.
            field_to_index: Optional mapping from key to the index of that
                key's value in ``data``.

        Returns:
            A new instance built by ``fromlist`` or ``fromdict``.
        """
        if field_to_index is None:
            return cls.fromlist(data, fields)
        # Kept as an assert so the exception type seen by callers does not
        # change.
        assert isinstance(fields, dict)
        data_dict = {f: data[idx] for f, idx in field_to_index.items()}
        return cls.fromdict(data_dict, fields)

    @classmethod
    def fromlist(cls, data, fields):
        """Build an example from positional values.

        Arguments:
            data: Iterable of raw values.
            fields: Iterable of ``(name, field)`` pairs, paired with ``data``
                by ``zip`` (so the shorter of the two decides how many values
                are consumed). A pair whose ``field`` is ``None`` is skipped.
                When ``name`` is a tuple, ``field`` is iterated alongside it
                and each ``(n, f)`` pair gets the same raw value.

        Returns:
            A new instance with one attribute per applied field.
        """
        ex = cls()
        for (name, field), val in zip(fields, data):
            if field is None:
                continue
            if isinstance(val, six.string_types):
                val = val.rstrip('\n')
            # Handle field tuples
            if isinstance(name, tuple):
                for n, f in zip(name, field):
                    setattr(ex, n, f.preprocess(val))
            else:
                setattr(ex, name, field.preprocess(val))
        return ex

    @classmethod
    def fromtree(cls, data, fields, subtrees=False):
        """Build example(s) from a bracketed tree string parsed by NLTK.

        Arguments:
            data: Tree string handed to ``nltk.tree.Tree.fromstring``.
            fields: Passed to ``fromlist`` together with the two values
                ``[' '.join(leaves), label]``.
            subtrees: When true, return a list with one example for each
                tree yielded by ``tree.subtrees()`` instead of a single
                example for the whole tree.

        Returns:
            A single instance, or a list of instances when ``subtrees`` is
            true.

        Raises:
            ImportError: If NLTK cannot be imported (re-raised after
                printing an installation hint).
        """
        try:
            from nltk.tree import Tree
        except ImportError:
            print("Please install NLTK. "
                  "See the docs at http://nltk.org for more information.")
            raise
        tree = Tree.fromstring(data)
        if subtrees:
            return [cls.fromlist([' '.join(t.leaves()), t.label()], fields)
                    for t in tree.subtrees()]
        return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields)
# |