nehalelkaref committed
Commit 17654dc · 1 Parent(s): 8a9a5d7

Delete utils.py

Files changed (1)
  1. utils.py +0 -415
utils.py DELETED
@@ -1,415 +0,0 @@
-import pickle
-
-import numpy as np
-import stanza
-from pyarabic import araby  # assumed source of the araby helpers (strip_tashkeel, strip_tatweel) used below
-
-
-class Config:
-    def __init__(self):
-        super(Config, self).__init__()
-
-
-def read_conll_ner(path):
-    with open(path) as f:
-        lines = f.readlines()
-    unique_entries = []
-    sentences = []
-    curr_sentence = []
-    for line in lines:
-        if not line.strip():
-            if curr_sentence:
-                sentences.append(curr_sentence)
-                curr_sentence = []
-            continue
-        entry = line.split()
-        curr_sentence.append(entry)
-        if not len(unique_entries):
-            unique_entries = [[] for _ in entry[1:]]
-        for e, values in zip(entry[1:], unique_entries):
-            if e not in values:
-                values.append(e)
-    return [sentences] + unique_entries
-
-
-def read_pickled_conll(path):
-    with open(path, "rb") as f:
-        data = pickle.load(f)
-    return data
-
-
-def split_conll_docs(conll_sents, skip_docstart=True):
-    docs = []
-    curr_doc = []
-    for sent in conll_sents:
-        if sent[0][0] == '-DOCSTART-':
-            if curr_doc:
-                docs.append(curr_doc)
-                curr_doc = []
-            if skip_docstart:
-                continue
-        curr_doc.append(sent)
-    docs.append(curr_doc)
-    return docs
-
-
-def create_context_data(docs, pos_col_id=1, tag_col_id=3, context_length=1, **kwargs):
-    ctx_type = kwargs.get("ctx_type", "other")
-    sep_token = kwargs.get("sep_token", "[SEP]")
-    if ctx_type == "cand_titles":
-        # create context for the candidate-titles scenario
-        for doc in docs:
-            doc["ctx_sent"] = doc["query"] + [sep_token] + f"<split>{sep_token}<split>".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
-        return docs
-    if ctx_type == "cand_links":
-        for doc in docs:
-            doc_titles_list = f"<split>{sep_token}<split>".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
-            linked_titles_list = f"<split>{sep_token}<split>".join([linked for cand in doc["BM25_cands"] for linked in cand["linked_titles"]]).split("<split>")
-            doc["ctx_sent"] = doc["query"] + [sep_token] + doc_titles_list + [sep_token] + linked_titles_list
-        return docs
-    if ctx_type == "raw_text":
-        # create context from the candidates' raw text
-        for doc in docs:
-            doc["ctx_sent"] = [doc["query"] + [sep_token] + [cand["processed_text"]] for cand in doc["BM25_cands"]]
-        return docs
-    if ctx_type == 'matched_spans':
-        matched_spans = kwargs.get('matched_spans')
-        return [
-            [[t[0] for t in d] + [t for m in ms for t in [sep_token] + m[1]],  # sentence tokens + spans
-             None,  # pos tags
-             [s[tag_col_id] for s in d] if tag_col_id > 0 else None,  # ner tags
-             [len(d)]  # sentence length
-             ]
-            for d, ms in zip(docs, matched_spans)]
-    if ctx_type == 'bm25_matched_spans':
-        matched_spans = kwargs.get('matched_spans')
-        pickled_data = kwargs.get('pickled_data')
-        docs = [[[t[0] for t in d] + [t for m in ms for t in [sep_token] + m[1]],  # sentence tokens + spans
-                 None,  # pos tags
-                 [s[tag_col_id] for s in d],  # ner tags
-                 [len(d)]  # sentence length
-                 ]
-                for d, ms in zip(docs, matched_spans)]
-        for entry, doc in zip(docs, pickled_data):
-            doc_titles_list = f"<split>{sep_token}<split>".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
-            linked_titles_list = f"<split>{sep_token}<split>".join([linked for cand in doc["BM25_cands"] for linked in cand["linked_titles"]]).split("<split>")
-            entry[0] = entry[0] + [sep_token] + doc_titles_list + [sep_token] + linked_titles_list
-        return docs
-    if ctx_type == "infobox":
-        infobox_keys_path = kwargs.get("infobox_keys_path")
-        infobox_keys = read_pickled_conll(infobox_keys_path)
-        if 'pred_spans' in docs[0]:
-            docs = get_pred_ent_bounds(docs)
-        for doc in docs:
-            if 'pred_spans' in doc:
-                ents = [' '.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['pred_ent_bounds']]
-                ents_wo_space = [''.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['pred_ent_bounds']]
-            else:
-                ents = [' '.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['ent_bounds']]
-                ents_wo_space = [''.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['ent_bounds']]
-            ents = list(set(ents + ents_wo_space))
-            infobox = [infobox_keys[en] for en in ents if en in infobox_keys and infobox_keys[en]]
-            for ibs in infobox:
-                ibs[0] = '[INFO] ' + ibs[0]
-                ibs[-1] = ibs[-1] + ' [/INFO]'
-            infobox = [i for j in infobox for i in j]
-            doc["ctx_sent"] = doc["query"] + [sep_token] + infobox
-        return docs
-    # create context for all other scenarios
-    res = []
-    for doc in docs:
-        ctx_len = context_length if context_length > 0 else len(doc)
-        # for the last sentences, loop around to the beginning of the doc for context
-        padded_doc = doc + doc[:ctx_len]
-        for i in range(len(doc)):
-            res.append((
-                [s[0] for sent in padded_doc[i:i+ctx_len] for s in sent],
-                [s[pos_col_id] for sent in padded_doc[i:i+ctx_len] for s in sent] if pos_col_id > 0 else None,
-                [s[tag_col_id] for sent in padded_doc[i:i+ctx_len] for s in sent],
-                [len(sent) for sent in padded_doc[i:i+ctx_len]],
-                {}  # dictionary for extra context
-            ))
-    return res
-
-
-def calc_correct(sentence):
-    gold_chunks = []
-    parallel_chunks = []
-    pred_chunks = []
-    curr_gold_chunk = []
-    curr_parallel_chunk = []
-    curr_pred_chunk = []
-    prev_tag = None
-    for line in sentence:
-        _, _, _, gt, pt = line
-        curr_tag = None
-        if '-' in pt:
-            curr_tag = pt.split('-')[1]
-        # track gold chunks together with the predicted tags that run parallel to them
-        if gt.startswith('B'):
-            if curr_gold_chunk:
-                gold_chunks.append(curr_gold_chunk)
-                parallel_chunks.append(curr_parallel_chunk)
-            curr_gold_chunk = [gt]
-            curr_parallel_chunk = [pt]
-        elif gt.startswith('I') or (pt.startswith('I') and curr_tag == prev_tag
-                                    and curr_gold_chunk):
-            curr_gold_chunk.append(gt)
-            curr_parallel_chunk.append(pt)
-        elif gt.startswith('O') and pt.startswith('O'):
-            if curr_gold_chunk:
-                gold_chunks.append(curr_gold_chunk)
-                parallel_chunks.append(curr_parallel_chunk)
-                curr_gold_chunk = []
-                curr_parallel_chunk = []
-        # track predicted chunks independently
-        if pt.startswith('O'):
-            if curr_pred_chunk:
-                pred_chunks.append(curr_pred_chunk)
-                curr_pred_chunk = []
-        elif pt.startswith('B'):
-            if curr_pred_chunk:
-                pred_chunks.append(curr_pred_chunk)
-            curr_pred_chunk = [pt]
-            prev_tag = curr_tag
-        else:
-            if prev_tag is not None and curr_tag != prev_tag:
-                prev_tag = curr_tag
-                if curr_pred_chunk:
-                    pred_chunks.append(curr_pred_chunk)
-                    curr_pred_chunk = []
-            curr_pred_chunk.append(pt)
-
-    if curr_gold_chunk:
-        gold_chunks.append(curr_gold_chunk)
-        parallel_chunks.append(curr_parallel_chunk)
-    if curr_pred_chunk:
-        pred_chunks.append(curr_pred_chunk)
-    correct = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
-                   if not len([1 for g, p in zip(gc, pc) if g != p])])
-    correct_tagless = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
-                           if not len([1 for g, p in zip(gc, pc) if g[0] != p[0]])])
-    # return correct, gold_chunks, parallel_chunks, pred_chunks, ob1_correct, correct_tagless
-    return {'correct': correct,
-            'correct_tagless': correct_tagless,
-            'gold_count': len(gold_chunks),
-            'pred_count': len(pred_chunks)}
-
-
-def tag_sentences(sentences):
-    nlp = stanza.Pipeline(lang='en', processors='tokenize,pos', logging_level='WARNING')
-    tagged_sents = []
-    for sentence in sentences:
-        n = nlp(sentence)
-        tagged_sent = []
-        for s in n.sentences:
-            for w in s.words:
-                tagged_sent.append([w.text, w.upos])
-        tagged_sents.append(tagged_sent)
-    return tagged_sents
-
-
-def extract_spans(sentence, tagless=False):
-    spans_positions = []
-    span_bounds = []
-    all_bounds = []
-    span_tags = []
-    curr_tag = None
-    curr_span = []
-    curr_span_start = -1
-    # span ids, span types
-    for i, token in enumerate(sentence):
-        if token.startswith('B'):
-            if curr_span:
-                spans_positions.append([curr_span, len(all_bounds)])
-                span_bounds.append([curr_span_start, i - 1])
-                all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
-                if not tagless:
-                    span_tags.append(token.split('-')[1])
-                curr_span = []
-                curr_tag = None
-            curr_span.append(token)
-            curr_tag = None if tagless else token.split('-')[1]
-            curr_span_start = i
-        elif token.startswith('I'):
-            if not tagless:
-                tag = token.split('-')[1]
-                if tag != curr_tag and curr_tag is not None:
-                    spans_positions.append([curr_span, len(all_bounds)])
-                    span_bounds.append([curr_span_start, i - 1])
-                    span_tags.append(token.split('-')[1])
-                    all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
-                    curr_span = []
-                    curr_tag = tag
-                    curr_span_start = i
-                elif curr_tag is None:
-                    curr_span = []
-                    curr_tag = tag
-                    curr_span_start = i
-            elif not curr_span:
-                curr_span_start = i
-            curr_span.append(token)
-        elif token.startswith('O') or token.startswith('-'):
-            if curr_span:
-                spans_positions.append([curr_span, len(all_bounds)])
-                span_bounds.append([curr_span_start, i - 1])
-                all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
-                curr_span = []
-                curr_tag = None
-            all_bounds.append([[i], 'W', len(all_bounds)])
-    # check if the sentence ended with a span
-    if curr_span:
-        spans_positions.append([curr_span, len(all_bounds)])
-        span_bounds.append([curr_span_start, len(sentence) - 1])
-        all_bounds.append([[curr_span_start, len(sentence) - 1], 'E', len(all_bounds)])
-    tagged_bounds = [[loc[0][0].split('-')[1] if '-' in loc[0][0] else loc[0][0], bound]
-                     for loc, bound in zip(spans_positions, span_bounds)]
-    return spans_positions, span_bounds, all_bounds, tagged_bounds
-
-
-def ner_corpus_stats(corpus_path):
-    onto_train_cols = read_conll_ner(corpus_path)
-    tags = list(set([t.split('-')[1] for t in onto_train_cols[3] if '-' in t]))
-    onto_train_spans = [extract_spans([t[3] for t in sent])[3] for sent in
-                        onto_train_cols[0]]
-    span_lens = [span[1][1] - span[1][0] + 1 for sent in onto_train_spans for
-                 span in sent]
-
-    len_stats = [span_lens.count(i + 1) / len(span_lens) for i in
-                 range(max(span_lens))]
-    flat_spans = [span for sent in onto_train_spans for span in sent]
-
-    tag_lens_dict = {k: [] for k in tags}
-    tag_counts_dict = {k: 0 for k in tags}
-    for span in flat_spans:
-        span_length = span[1][1] - span[1][0] + 1
-        span_tag = span[0][0].split('-')[1]
-        tag_lens_dict[span_tag].append(span_length)
-        tag_counts_dict[span_tag] += 1
-
-    x = list(tag_counts_dict.items())
-    x.sort(key=lambda l: l[1])
-    tag_counts = [list(l) for l in x]
-    for l in tag_counts:
-        l[1] = l[1] / len(span_lens)
-
-    tag_len_stats = {k: [v.count(i + 1) / len(v) for i in range(max(v))]
-                     for k, v in tag_lens_dict.items()}
-    span_texts = [sent[span[1][0]:span[1][1] + 1]
-                  for sent, spans in zip(onto_train_cols[0], onto_train_spans)
-                  for span in spans]
-    span_pos = [[span[0][-1].split('-')[1], '_'.join(t[1] for t in span)]
-                for span in span_texts]
-    unique_pos = list(set([span[1] for span in span_pos]))
-    pos_dict = {k: 0 for k in unique_pos}
-    for span in span_pos:
-        pos_dict[span[1]] += 1
-    unique_pos.sort(key=lambda l: pos_dict[l], reverse=True)
-    pos_stats = [[p, pos_dict[p] / len(span_pos)] for p in unique_pos]
-    tag_pos_dict = {kt: {kp: 0 for kp in unique_pos} for kt in tags}
-    for span in span_pos:
-        tag_pos_dict[span[0]][span[1]] += 1
-    tag_pos_stats = {kt: [[p, tag_pos_dict[kt][p] / tag_counts_dict[kt]]
-                          for p in unique_pos] for kt in tags}
-    for kt in tags:
-        tag_pos_stats[kt].sort(key=lambda l: l[1], reverse=True)
-
-    return len_stats, tag_len_stats, tag_counts, pos_stats, tag_pos_stats
-
-
-def filter_by_max_ents(sentences, max_ent_length):
-    """
-    Filters a given list of sentences and only returns the sentences whose
-    named entities are shorter than or equal to the given max_ent_length.
-
-    :param sentences: sentences in conll format as extracted by read_conll_ner
-    :param max_ent_length: the maximum number of tokens in an entity
-    :return: a list of sentences
-    """
-    filtered_sents = []
-    for sent in sentences:
-        sent_span_lens = [s[1] - s[0] + 1
-                          for s in extract_spans([t[3] for t in sent])[1]]
-        if not sent_span_lens or max(sent_span_lens) <= max_ent_length:
-            filtered_sents.append(sent)
-    return filtered_sents
-
-
-def get_pred_ent_bounds(docs):
-    for doc in docs:
-        eb = []
-        count = 0
-        for p_eb in doc['pred_spans']:
-            if p_eb == 'B':
-                eb.append([count, count])
-            elif p_eb == 'I' and len(eb) > 0:
-                eb[-1][1] = count
-            count += 1
-        doc['pred_ent_bounds'] = eb
-    return docs
-
-def enumerate_spans(batch):
-    enumerated_spans_batch = []
-    for idx in range(0, len(batch)):
-        sentence = batch[idx]
-        enumerated_spans = []
-        for x in range(len(sentence)):
-            for y in range(x, len(sentence)):
-                enumerated_spans.append([x, y])
-        enumerated_spans_batch.append(enumerated_spans)
-    return enumerated_spans_batch
-
-def compact_span_enumeration(batch):
-    # compact equivalent of enumerate_spans: all [x, y] spans with y >= x
-    sentence_lengths = [len(b) for b in batch]
-    enumerated_spans = [[[x, y]
-                         for x in range(sentence_length)
-                         for y in range(x, sentence_length)]
-                        for sentence_length in sentence_lengths]
-    return enumerated_spans
-
-def preprocess_data(data):
-    clean_data = []
-    for sample in data:
-        # strip Arabic diacritics and tatweel from the token column
-        clean_tokens = [araby.strip_tashkeel(token) for token in sample[0]]
-        clean_tokens = [araby.strip_tatweel(token) for token in clean_tokens]
-        clean_sample = [clean_tokens]
-        clean_sample.extend(sample[1:])
-        clean_data.append(clean_sample)
-    return clean_data
-
-
-def generate_targets(enumerated_spans, sentences):
-    #### could be refactored into a helper function ####
-    extracted_spans = [extract_spans(sentence, True)[3] for sentence in sentences]
-    target_locations = []
-    for span in extracted_spans:
-        sentence_locations = []
-        for location in span:
-            sentence_locations.append(location[1])
-        target_locations.append(sentence_locations)
-    #### could be refactored into a helper function ####
-
-    targets = []
-    for span, location_list in zip(enumerated_spans, target_locations):
-        # mark enumerated spans that correspond to gold spans with 1, the rest with 0
-        span_arr = np.zeros_like(span).tolist()
-        target_indices = [span.index(span_location) for
-                          span_location in location_list]
-        for idx in target_indices:
-            span_arr[idx] = 1
-        span_arr = [0 if x != 1 else x for x in span_arr]
-        targets.append(list(span_arr))
-
-    return targets
-
-def label_tags(tags):
-    output_tags = []
-    for tag in tags:
-        if tag == "O":
-            output_tags.append(0)
-        else:
-            output_tags.append(1)
-    return output_tags