File size: 7,539 Bytes
4fb0bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import json
from collections import defaultdict
import logging

logger = logging.getLogger(__name__)


class ReaderForRelationDecoding():
    """Read and preprocess ACE-style JSON-lines data for entity-relation
    joint decoding.

    Each line of the input file is a JSON object describing one sentence
    (tokens, wordpiece tokens, and label/relation/argument id sequences).
    Iterating over the reader yields one preprocessed ``sentence`` dict per
    valid line; malformed lines are skipped, and over-length lines are
    dropped only during training (kept when ``is_test`` is True).
    """

    def __init__(self, file_path, is_test=False, max_len=None):
        """This function defines file path and some settings

        Arguments:
            file_path {str} -- file path

        Keyword Arguments:
            is_test {bool} -- indicate training or testing (default: {False})
            max_len {dict} -- max length for some namespace; sequences longer
                than this are dropped during training (default: {None},
                meaning no limits)
        """

        self.file_path = file_path
        self.is_test = is_test
        # Copy so later caller-side mutation cannot affect the reader; `None`
        # (instead of a shared mutable default) means "no length limits".
        self.max_len = dict(max_len) if max_len else {}
        self.seq_lens = defaultdict(list)

    def __iter__(self):
        """Generator function: yield one preprocessed sentence dict per
        valid input line.
        """

        # JSON-lines data is assumed UTF-8; the encoding is made explicit so
        # behavior does not depend on the platform default.
        with open(self.file_path, 'r', encoding='utf-8') as fin:
            for line in fin:
                line = json.loads(line)
                sentence = {}

                state, results = self.get_tokens(line)
                # Record the length only when tokenization succeeded; a
                # failed `get_tokens` returns an empty dict and previously
                # raised KeyError here instead of skipping the line.
                if 'tokens' in results:
                    self.seq_lens['tokens'].append(len(results['tokens']))
                if not state or ('tokens' in self.max_len
                                 and len(results['tokens']) > self.max_len['tokens']):
                    # Invalid/over-length sentences are dropped only in training.
                    if not self.is_test:
                        continue
                sentence.update(results)

                state, results = self.get_wordpiece_tokens(line)
                if 'wordpiece_tokens' in results:
                    self.seq_lens['wordpiece_tokens'].append(len(results['wordpiece_tokens']))
                if not state or ('wordpiece_tokens' in self.max_len
                                 and len(results['wordpiece_tokens']) > self.max_len['wordpiece_tokens']):
                    if not self.is_test:
                        continue
                sentence.update(results)

                if len(sentence['wordpiece_tokens']) != len(sentence['wordpiece_segment_ids']):
                    logger.error(
                        "sentence id: {} wordpiece_tokens length is not equal to wordpiece_segment_ids.".
                            format(line['sentId']))
                    continue

                state, results = self.get_label(line, len(sentence['tokens']))
                # Length bookkeeping for every label namespace; a sequence
                # over its configured max invalidates the whole line.
                for key, result in results.items():
                    self.seq_lens[key].append(len(result))
                    if key in self.max_len and len(result) > self.max_len[key]:
                        state = False
                if not state:
                    continue
                sentence.update(results)

                yield sentence

    def get_tokens(self, line):
        """This function splits text into tokens

        Arguments:
            line {dict} -- text

        Returns:
            bool -- execute state (False when 'sentText' is missing)
            dict -- results: text, tokens
        """

        results = {}

        if 'sentText' not in line:
            logger.error("sentence id: {} doesn't contain 'sentText'.".format(line['sentId']))
            return False, results

        results['text'] = line['sentText']

        # Prefer pre-computed tokens; fall back to whitespace splitting.
        if 'tokens' in line:
            results['tokens'] = line['tokens']
        else:
            results['tokens'] = line['sentText'].strip().split(' ')

        return True, results

    def get_wordpiece_tokens(self, line):
        """This function splits wordpiece text into wordpiece tokens

        Arguments:
            line {dict} -- text

        Returns:
            bool -- execute state (False when any wordpiece key is missing)
            dict -- results: wordpiece_tokens, wordpiece_tokens_index,
                wordpiece_segment_ids
        """

        results = {}

        if 'wordpieceSentText' not in line or 'wordpieceTokensIndex' not in line or 'wordpieceSegmentIds' not in line:
            logger.error(
                "sentence id: {} doesn't contain 'wordpieceSentText' or 'wordpieceTokensIndex' or 'wordpieceSegmentIds'."
                    .format(line['sentId']))
            return False, results

        wordpiece_tokens = line['wordpieceSentText'].strip().split(' ')
        results['wordpiece_tokens'] = wordpiece_tokens
        # Keep only the start offset of each token's wordpiece span.
        results['wordpiece_tokens_index'] = [span[0] for span in line['wordpieceTokensIndex']]
        results['wordpiece_segment_ids'] = list(line['wordpieceSegmentIds'])

        return True, results

    def get_label(self, line, sentence_length):
        """This function validates entity/relation mentions and collects the
        label id sequences.

        Arguments:
            line {dict} -- text
            sentence_length {int} -- sentence length

        Returns:
            bool -- execute state (False on missing keys, bad offsets or
                overlapping entity spans)
            dict -- results: label_ids, relation_ids, argument_ids
        """

        results = {}

        if 'entityMentions' not in line:
            logger.error("sentence id: {} doesn't contain 'entityMentions'.".format(line['sentId']))
            return False, results

        entity_pos = [0] * sentence_length
        idx2ent = {}
        for entity in line['entityMentions']:
            st, ed = entity['span_ids'][0], entity['span_ids'][-1]
            idx2ent[entity['emId']] = ((st, ed), entity['text'])
            if st > ed or st < 0 or st > sentence_length or ed < 0 or ed > sentence_length:
                logger.error("sentence id: {} offset error. st: {}, ed: {}".format(line['sentId'], st, ed))
                return False, results

            # Mark the occupied token positions; overlapping spans invalidate
            # the sentence.
            for i in range(st, ed):
                if entity_pos[i] != 0:
                    logger.error("sentence id: {} entity span overlap.".format(line['sentId']))
                    return False, results
                entity_pos[i] = 1

        if 'relationMentions' not in line:
            logger.error("sentence id: {} doesn't contain 'relationMentions'.".format(line['sentId']))
            return False, results

        # Validation-only pass: relations referencing unknown entities are
        # logged but do not invalidate the sentence.
        for relation in line['relationMentions']:
            if relation['arg1']['emId'] not in idx2ent or relation['arg2']['emId'] not in idx2ent:
                logger.error("sentence id: {} entity not exists .".format(line['sentId']))

        # The three id sequences are all required; the log strings match the
        # original per-key messages exactly.
        for key in ('labelIds', 'relationIds', 'argumentIds'):
            if key not in line:
                logger.error("sentence id: {} doesn't contain '{}'.".format(line['sentId'], key))
                return False, results

        results["label_ids"] = line["labelIds"]
        results["relation_ids"] = line["relationIds"]
        results["argument_ids"] = line["argumentIds"]

        return True, results

    def get_seq_lens(self):
        """Return the recorded sequence lengths per namespace."""
        return self.seq_lens