File size: 3,692 Bytes
8044721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import six
import json
from functools import reduce


class Example(object):
    """Defines a single training or test example.

    Stores each column of the example as an attribute.  Instances are built
    through the ``from*`` alternate constructors below; every constructor
    passes the raw column value through ``field.preprocess`` before storing
    it under the field's name.
    """

    @classmethod
    def fromJSON(cls, data, fields):
        """Build an example from a JSON document.

        Arguments:
            data: JSON text, decoded with ``json.loads``.
            fields: dict mapping a key to ``None`` (entry ignored), a single
                ``(name, field)`` pair, or a list of such pairs.  A key may
                be a dotted path such as ``'foo.bar'``; each path component
                is looked up in turn.  When a component is applied to a
                list, it is looked up in every element of the list and the
                results are collected into a new list.

        Returns:
            A new instance with one attribute per ``(name, field)`` pair.

        Raises:
            ValueError: if a path component cannot be found, including the
                case where the path descends into a value that is not a
                JSON object (string, number, ``null``, nested list).
        """
        def lookup(node, part):
            # Only dicts can hold keys.  Without the isinstance check a
            # string node would be substring-tested by ``in`` and a number
            # or None would raise TypeError instead of the documented
            # ValueError.
            if not isinstance(node, dict) or part not in node:
                raise ValueError("Specified key {} was not found in "
                                 "the input data".format(part))
            return node[part]

        def descend(node, part):
            # A list fans the lookup out over its elements.
            if isinstance(node, list):
                return [lookup(item, part) for item in node]
            return lookup(node, part)

        ex = cls()
        obj = json.loads(data)

        for key, vals in fields.items():
            if vals is None:
                continue
            if not isinstance(vals, list):
                vals = [vals]

            # for processing the key likes 'foo.bar'
            path = key.split('.')
            for name, field in vals:
                # The path is resolved once per field so that each field
                # receives its own freshly built list when lists are
                # involved.
                value = reduce(descend, path, obj)
                setattr(ex, name, field.preprocess(value))
        return ex

    @classmethod
    def fromdict(cls, data, fields):
        """Build an example from a dict of raw column values.

        Arguments:
            data: mapping from key to raw value.
            fields: dict mapping a key to ``None``, a single
                ``(name, field)`` pair, or a list of such pairs.

        Returns:
            A new instance with one attribute per ``(name, field)`` pair.

        Raises:
            ValueError: if a key of ``fields`` is absent from ``data``.  The
                check is made even when the entry's value is ``None``.
        """
        ex = cls()
        for key, vals in fields.items():
            if key not in data:
                raise ValueError("Specified key {} was not found in "
                                 "the input data".format(key))
            if vals is None:
                continue
            if not isinstance(vals, list):
                vals = [vals]
            for name, field in vals:
                setattr(ex, name, field.preprocess(data[key]))
        return ex

    @classmethod
    def fromCSV(cls, data, fields, field_to_index=None):
        """Build an example from one row of column values.

        Arguments:
            data: sequence of column values, indexable by position.
            fields: passed to ``fromlist`` when ``field_to_index`` is
                ``None``; otherwise it must be a dict and is passed to
                ``fromdict``.
            field_to_index: optional dict mapping a key to the index of its
                column in ``data``.

        Returns:
            The instance produced by ``fromlist`` or ``fromdict``.
        """
        if field_to_index is None:
            return cls.fromlist(data, fields)

        assert isinstance(fields, dict)
        # Re-key the positional row so that fromdict can do the work.
        data_dict = {f: data[idx] for f, idx in field_to_index.items()}
        return cls.fromdict(data_dict, fields)

    @classmethod
    def fromlist(cls, data, fields):
        """Build an example from positional column values.

        Arguments:
            data: iterable of raw values, paired positionally with
                ``fields`` (pairing stops at the shorter of the two).
            fields: iterable of ``(name, field)`` pairs.  A pair whose
                ``field`` is ``None`` skips that column.  ``name`` may be a
                tuple, in which case ``field`` is iterated alongside it and
                the same value is stored once per ``(name, field)`` element.

        Returns:
            A new instance with one attribute per non-skipped name.
        """
        ex = cls()
        for (name, field), val in zip(fields, data):
            if field is None:
                continue
            if isinstance(val, six.string_types):
                # Drop trailing newline characters from text values.
                val = val.rstrip('\n')
            # Handle field tuples
            if isinstance(name, tuple):
                for n, f in zip(name, field):
                    setattr(ex, n, f.preprocess(val))
            else:
                setattr(ex, name, field.preprocess(val))
        return ex

    @classmethod
    def fromtree(cls, data, fields, subtrees=False):
        """Build example(s) from a bracketed tree string using NLTK.

        Arguments:
            data: tree text, parsed with ``nltk.tree.Tree.fromstring``.
            fields: forwarded to ``fromlist``; the two columns supplied are
                the space-joined leaves and the node label, in that order.
            subtrees: if true, return a list with one example per subtree
                yielded by ``tree.subtrees()`` instead of a single example.

        Returns:
            A single instance, or a list of instances when ``subtrees`` is
            true.

        Raises:
            ImportError: re-raised, after printing an installation hint,
                when NLTK cannot be imported.
        """
        try:
            from nltk.tree import Tree
        except ImportError:
            print("Please install NLTK. "
                  "See the docs at http://nltk.org for more information.")
            raise
        tree = Tree.fromstring(data)
        if subtrees:
            return [cls.fromlist(
                [' '.join(t.leaves()), t.label()], fields) for t in tree.subtrees()]
        return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields)