# Page header captured together with this file (file-hosting view):
#   author: larkkin
#   commit 8044721: "Add application code and models, update README"
#   (raw | history | blame) -- 3.69 kB
import six
import json
from functools import reduce
class Example(object):
    """Defines a single training or test example.

    Stores each column of the example as an attribute. Every alternate
    constructor below creates an empty instance and then sets one attribute
    per ``(name, field)`` pair, whose value is ``field.preprocess(raw_value)``.
    ``field`` may be any object exposing a ``preprocess(value)`` method.
    """

    @staticmethod
    def _require_key(container, key):
        """Return ``container[key]``.

        Raises:
            ValueError: if ``key`` is not in ``container``.
        """
        if key not in container:
            raise ValueError("Specified key {} was not found in "
                             "the input data".format(key))
        return container[key]

    @classmethod
    def _resolve_key(cls, obj, key):
        """Descend one level into a parsed JSON value (a ``reduce`` step).

        If ``obj`` is a list, ``key`` is looked up directly in each of its
        elements and the results are returned as a new list; otherwise
        ``obj[key]`` is returned.

        Raises:
            ValueError: if ``key`` is missing from ``obj`` or from any
                element of ``obj`` when it is a list.
        """
        if isinstance(obj, list):
            return [cls._require_key(item, key) for item in obj]
        return cls._require_key(obj, key)

    @classmethod
    def fromJSON(cls, data, fields):
        """Build an example from a JSON document.

        Args:
            data: JSON text, decoded with ``json.loads``.
            fields: mapping from a key to ``None`` (skip the key), a single
                ``(name, field)`` pair, or a list of such pairs. A dotted key
                such as ``'foo.bar'`` addresses the nested value
                ``obj['foo']['bar']``; when a path step lands on a list, the
                next key is looked up in every element of that list and the
                result is a list.

        Returns:
            A new instance with one attribute per ``(name, field)`` pair.

        Raises:
            ValueError: if a key on a non-skipped path is missing (or, via
                ``json.loads``, if ``data`` is not valid JSON).
        """
        ex = cls()
        obj = json.loads(data)
        for key, vals in fields.items():
            if vals is None:
                continue
            if not isinstance(vals, list):
                vals = [vals]
            # The dotted path depends only on the key, so split it once.
            path = key.split('.')
            for name, field in vals:
                # Resolved per pair so each field gets its own result
                # (a freshly built list whenever the path crosses a list).
                value = reduce(cls._resolve_key, path, obj)
                setattr(ex, name, field.preprocess(value))
        return ex

    @classmethod
    def fromdict(cls, data, fields):
        """Build an example from a mapping.

        Args:
            data: mapping from key to raw value.
            fields: mapping from key to ``None``, a ``(name, field)`` pair,
                or a list of such pairs. Keys are looked up verbatim (no
                dotted-path handling).

        Returns:
            A new instance with one attribute per ``(name, field)`` pair.

        Raises:
            ValueError: if any key of ``fields`` -- including keys mapped to
                ``None`` -- is missing from ``data``.
        """
        ex = cls()
        for key, vals in fields.items():
            # Every listed key must be present, even ones that are skipped.
            value = cls._require_key(data, key)
            if vals is None:
                continue
            if not isinstance(vals, list):
                vals = [vals]
            for name, field in vals:
                setattr(ex, name, field.preprocess(value))
        return ex

    @classmethod
    def fromCSV(cls, data, fields, field_to_index=None):
        """Build an example from one CSV row given as indexable column values.

        Args:
            data: indexable row of column values (presumably already split
                by a CSV reader -- confirm against callers).
            fields: without ``field_to_index``, a sequence of
                ``(name, field)`` pairs aligned with the columns (see
                ``fromlist``); with it, a dict as accepted by ``fromdict``.
            field_to_index: optional mapping from each key of ``fields`` to
                the column index holding its value.

        Raises:
            TypeError: if ``field_to_index`` is given but ``fields`` is not
                a dict.
        """
        if field_to_index is None:
            return cls.fromlist(data, fields)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(fields, dict):
            raise TypeError("fields must be a dict when field_to_index is "
                            "given, got {}".format(type(fields).__name__))
        data_dict = {f: data[idx] for f, idx in field_to_index.items()}
        return cls.fromdict(data_dict, fields)

    @classmethod
    def fromlist(cls, data, fields):
        """Build an example from a sequence of column values.

        Args:
            data: sequence of raw column values.
            fields: sequence of ``(name, field)`` pairs aligned positionally
                with ``data``. ``field`` may be ``None`` to skip the column.
                ``name`` may be a tuple, in which case ``field`` is a parallel
                tuple and the column is preprocessed by each sub-field into
                its own attribute.

        Columns and fields are paired with ``zip``, so pairing stops at the
        shorter of the two; extra entries in either are ignored.
        """
        ex = cls()
        for (name, field), val in zip(fields, data):
            if field is None:
                continue
            if isinstance(val, six.string_types):
                # Strip trailing newlines (presumably left over from
                # line-based reading of the source file).
                val = val.rstrip('\n')
            if isinstance(name, tuple):
                for sub_name, sub_field in zip(name, field):
                    setattr(ex, sub_name, sub_field.preprocess(val))
            else:
                setattr(ex, name, field.preprocess(val))
        return ex

    @classmethod
    def fromtree(cls, data, fields, subtrees=False):
        """Build example(s) from a tree string using NLTK.

        Each example is built with ``fromlist`` from
        ``[' '.join(leaves), label]``, so ``fields`` should hold two
        ``(name, field)`` pairs: text first, label second.

        Args:
            data: tree string parsed with ``nltk.tree.Tree.fromstring``.
            fields: as accepted by ``fromlist``.
            subtrees: if true, return a list with one example per subtree
                yielded by ``Tree.subtrees()``; otherwise return a single
                example for the whole tree.

        Raises:
            ImportError: if NLTK cannot be imported (an install hint is
                printed first).
        """
        try:
            from nltk.tree import Tree
        except ImportError:
            print("Please install NLTK. "
                  "See the docs at http://nltk.org for more information.")
            raise
        tree = Tree.fromstring(data)
        if subtrees:
            return [cls.fromlist(
                [' '.join(t.leaves()), t.label()], fields) for t in tree.subtrees()]
        return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields)