nehalelkaref committed on
Commit
6b35cc5
·
1 Parent(s): 62240fd

Upload 3 files

Browse files
Files changed (3) hide show
  1. network.py +333 -0
  2. utils.py +420 -0
  3. validate.py +168 -0
network.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn.utils.rnn import pad_sequence
6
+ from torch.nn.functional import cross_entropy, binary_cross_entropy
7
+ from tqdm.auto import tqdm
8
+
9
+ from .utils import Config, extract_spans, generate_targets
10
+ from .representation import TransformerRepresentation
11
+ from .layers import SpanEnumerationLayer
12
+
13
# Device used when a caller does not pass one: the GPU if torch can see one, else the CPU.
DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SpanNet(nn.Module):
    """Span detector built on a transformer encoder.

    Two modes are selected with the ``repr_type`` keyword:

    * ``'token_classification'`` scores every token over ``span_tags`` (B/I/O);
    * ``'span_enumeration'`` scores every candidate span produced by
      ``SpanEnumerationLayer`` with a single logit.

    Recognised keyword arguments: ``pos``, ``dp``, ``transformer_model_name``,
    ``token_pooling``, ``device`` and ``repr_type``. All of them except
    ``device`` are stored on ``self.config`` and therefore end up in
    checkpoints written by :meth:`save_model`.
    """

    def __init__(self, **kwargs):
        super(SpanNet, self).__init__()
        self.config = Config()
        self.config.pos = kwargs.get('pos', None)
        self.config.dp = kwargs.get('dp', 0.3)
        self.config.transformer_model_name = kwargs.get(
            'transformer_model_name', 'bert-base-uncased')
        self.config.token_pooling = kwargs.get('token_pooling', 'sum')
        self.device = kwargs.get('device', DEFAULT_DEVICE)

        self.config.repr_type = kwargs.get('repr_type', 'token_classification')
        assert self.config.repr_type in ['token_classification',
                                         'span_enumeration'], 'Invalid representaton type'

        self.transformer = TransformerRepresentation(
            model_name=self.config.transformer_model_name,
            device=self.device).to(self.device)

        self.transformer_dim = self.transformer.embedding_dim
        if self.config.pos:
            # every POS label becomes a special token of the form "[LABEL]"
            self.transformer.add_special_tokens(
                [f'[{p}]' for p in self.config.pos])

        self.span_tags = ['B', 'I', 'O']

        self.enumeration_layer = SpanEnumerationLayer()
        output_size = {'token_classification': len(self.span_tags),
                       'span_enumeration': 1}
        self.span_output_layer = nn.Sequential(
            nn.Linear(self.transformer_dim, self.transformer_dim),
            nn.ReLU(), nn.Dropout(p=self.config.dp),
            nn.Linear(self.transformer_dim, output_size[self.config.repr_type]))

    def to_dict(self):
        """Return the checkpoint payload: config attributes plus the state dict."""
        return {
            'model_config': self.config.__dict__,
            'model_state_dict': self.state_dict()
        }

    @classmethod
    def load_model(cls, model_path, device=DEFAULT_DEVICE):
        """Rebuild a model from a checkpoint written by :meth:`save_model`.

        The model is constructed on ``device`` and moved there as a whole;
        previously the argument was only used as ``map_location`` and the
        network itself was always built on ``DEFAULT_DEVICE``. The returned
        model is in eval mode.
        """
        res = torch.load(model_path, map_location=device)
        model_kwargs = dict(res['model_config'])
        model_kwargs['device'] = device
        model = cls(**model_kwargs)
        model.load_state_dict(res['model_state_dict'])
        model.to(device)
        model.eval()
        return model

    @classmethod
    def preds_to_sequences(cls, predictions, enumerations, length):
        """Turn the candidate spans of ONE sentence into a B/I/O tag list.

        :param predictions: indexable of one-element tensors, one per span
        :param enumerations: list of ``[start, end]`` pairs (end inclusive)
        :param length: number of tags to start from (all ``'O'``)
        :return: list of ``'B'`` / ``'I'`` / ``'O'`` strings

        NOTE(review): the scores are only used as dictionary keys, so spans
        with identical scores collapse into one entry, and the ordering below
        is by span coordinates, not by score. Behaviour is kept as it was;
        confirm whether score-based ranking was intended.
        """
        score_to_span = {predictions[idx].item(): enumerations[idx]
                         for idx in range(len(enumerations))}
        ordered = [span for _, span in
                   sorted(score_to_span.items(), key=lambda kv: kv[1], reverse=True)]

        # walk the ordered spans and drop later ones that clash with the current one
        kept = list(ordered)
        i = 0
        while i != len(kept):
            s, e = kept[i]
            clashing = []
            for j in range(i + 1, len(kept)):
                sj, ej = kept[j]
                if (sj < s <= ej <= e) or (s < sj and e < ej):
                    clashing.append(kept[j])
            i += 1
            kept = [span for span in kept if span not in clashing]

        tagged_seq = ['O'] * length
        for s, e in kept:
            tagged_seq[s] = 'B'
            if e - s > 0:
                tagged_seq[s + 1:e + 1] = ['I'] * (e - s)
        return tagged_seq

    def save_model(self, output_path):
        """Write :meth:`to_dict` to ``output_path`` with ``torch.save``."""
        torch.save(self.to_dict(), output_path)

    def _extract_sentence_vectors(self, sentences, pos=None):
        """Encode pre-tokenized sentences and return the pooled token vectors.

        When POS tags are given and the model was configured with ``pos``,
        every token is prefixed with its ``[POS]`` marker before encoding.
        """
        if pos and self.config.pos:
            sentences = [[f'[{tag}] {tok}' for tok, tag in zip(sent, sent_pos)]
                         for sent, sent_pos in zip(sentences, pos)]
        outs = self.transformer(sentences, is_pretokenized=True,
                                token_pooling=self.config.token_pooling)
        return outs.pooled_tokens

    def forward(self, sentences, pos=None, tags=None, **kwargs):
        """Score a batch and, when ``tags`` is given, compute the training loss.

        :param sentences: list of token lists
        :param pos: optional list of POS-tag lists aligned with ``sentences``
        :param tags: optional gold tag sequences; enables the ``'loss'`` output
        :param kwargs: ``output_word_vecs`` / ``output_span_scores`` add the
            pooled token vectors / raw scores to the result
        :return: dict with ``'pred_tags'`` and optionally ``'word_vecs'``,
            ``'span_scores'`` and ``'loss'``
        """
        out_dict = {}
        embs = self._extract_sentence_vectors(sentences, pos)
        if kwargs.get('output_word_vecs', False):
            # was `embeddings`, an undefined name
            out_dict['word_vecs'] = embs

        lens = [len(s) for s in embs]
        is_enumeration = self.config.repr_type == 'span_enumeration'

        enumerations = None
        if is_enumeration:
            embs, enumerations = self.enumeration_layer(embs, lens)
            lens = [len(e) for e in enumerations]

        input_layer = pad_sequence(embs, batch_first=True)

        # drop the padded positions of every sentence again
        span_scores = [torch.unbind(f)[:l]
                       for f, l in zip(self.span_output_layer(input_layer), lens)]

        if kwargs.get('output_span_scores', False):
            out_dict['span_scores'] = span_scores

        if is_enumeration:
            # scores are already grouped per sentence, so no flat offset
            # bookkeeping is needed (the old code read an unassigned variable)
            sent_lens = [len(s) for s in sentences]
            out_dict['pred_tags'] = [
                self.preds_to_sequences(scores, enum, length)
                for scores, enum, length in zip(span_scores, enumerations, sent_lens)]
        else:
            out_dict['pred_tags'] = self.from_span_scores(span_scores)

        if tags is None:
            return out_dict

        if is_enumeration:
            targets = generate_targets(enumerations, tags)
            # torch.cat keeps the autograd graph; rebuilding the scores with
            # torch.Tensor([...]) produced a detached copy that cannot train
            flat_scores = torch.cat(
                [s.reshape(-1) for sent_scores in span_scores for s in sent_scores])
            targets = torch.tensor([t for st in targets for t in st],
                                   dtype=flat_scores.dtype,
                                   device=flat_scores.device)
            span_loss = binary_cross_entropy(flat_scores.sigmoid(), targets)
        else:
            # limit the targets of each sentence to the words not truncated
            # during tokenization (zip stops at the shorter sequence)
            flat_scores = torch.stack(
                [s for tg, sc in zip(tags, span_scores) for _, s in zip(tg, sc)])
            targets = torch.cat(
                [torch.tensor([self.span_tags.index(t[0]) for t, _ in zip(tg, sc)])
                 for tg, sc in zip(tags, span_scores)]).to(flat_scores.device)
            span_loss = cross_entropy(flat_scores, targets)

        out_dict['loss'] = span_loss
        return out_dict

    def from_span_scores(self, span_scores):
        """Map per-token score vectors to their arg-max tag from ``span_tags``."""
        pred_span_ids = [[torch.argmax(s) for s in sc] for sc in span_scores]
        return [[self.span_tags[idx] for idx in sequence]
                for sequence in pred_span_ids]
180
+
181
+
182
class EntNet(nn.Module):
    """Entity-type classifier that labels spans proposed by a span detector.

    Span boundaries come either from ``self.span_net`` or from the
    ``pred_tags`` keyword passed to :meth:`forward`. Each span is represented
    by the sum of its token vectors concatenated with the sum of the vectors
    of a copy of the span appended after the sentence, and is scored over
    ``ent_tags``.

    Recognised keyword arguments: ``span_net``, ``tune_span_net``,
    ``use_span_emb``, ``use_ent_markers``, ``ent_tags``, ``pos``, ``dp``,
    ``transformer_model_name``, ``token_pooling`` and ``device``.
    """

    def __init__(self, **kwargs):
        super(EntNet, self).__init__()
        self.config = Config()
        self.span_net = kwargs.get('span_net')
        self.config.tune_span_net = kwargs.get('tune_span_net', False)
        self.config.use_span_emb = kwargs.get('use_span_emb', False)
        self.config.use_ent_markers = kwargs.get('use_ent_markers', False)
        # it is possible to tune span_net without using its embeddings
        if self.span_net and not self.config.tune_span_net:
            for p in self.span_net.parameters():
                p.requires_grad = False
        self.config.ent_tags = self.ent_tags = kwargs.get('ent_tags')
        self.config.pos = kwargs.get('pos', None)
        self.config.dp = kwargs.get('dp', 0.3)
        self.config.transformer_model_name = kwargs.get(
            'transformer_model_name', 'bert-base-uncased')
        self.config.token_pooling = kwargs.get('token_pooling', 'first')
        self.device = kwargs.get('device', DEFAULT_DEVICE)

        self.transformer = TransformerRepresentation(
            model_name=self.config.transformer_model_name,
            device=self.device).to(self.device)
        self.transformer_dim = self.transformer.embedding_dim

        self.transformer.add_special_tokens(['[ENT]', '[/ENT]'])
        self.transformer.add_special_tokens(['[INFO]', '[/INFO]'])
        if self.config.pos:
            self.transformer.add_special_tokens(
                ['[' + p + ']' for p in self.config.pos])

        # input is twice the encoder width: in-sentence span + appended copy
        self.ent_output_layer = nn.Sequential(
            nn.Linear(2 * self.transformer_dim, 2 * self.transformer_dim),
            nn.ReLU(), nn.Dropout(p=self.config.dp),
            nn.Linear(2 * self.transformer_dim, len(self.config.ent_tags)))

    def to_dict(self):
        """Return the checkpoint payload (own config, span-net config, weights)."""
        return {
            'model_config': self.config.__dict__,
            'span_net_config': self.span_net.config.__dict__ if self.span_net is not None else None,
            'model_state_dict': self.state_dict()
        }

    @classmethod
    def load_model(cls, model_path, device=DEFAULT_DEVICE):
        """Rebuild a model (and its span net, if any) from a checkpoint.

        Both networks are constructed on ``device`` and the result is moved
        there as a whole; previously ``device`` was only used as
        ``map_location``. The returned model is in eval mode.
        """
        res = torch.load(model_path, map_location=device)
        span_net = None
        if res['span_net_config'] is not None:
            span_kwargs = dict(res['span_net_config'])
            span_kwargs['device'] = device
            span_net = SpanNet(**span_kwargs)
        model_kwargs = dict(res['model_config'])
        model_kwargs['device'] = device
        model = cls(span_net=span_net, **model_kwargs)
        model.load_state_dict(res['model_state_dict'])
        model.to(device)
        model.eval()
        return model

    def save_model(self, output_path):
        """Write :meth:`to_dict` to ``output_path`` with ``torch.save``."""
        torch.save(self.to_dict(), output_path)

    def _extract_sentence_vectors(self, sentences, pos=None, ent_bounds=None):
        """Encode pre-tokenized sentences and return the pooled token vectors.

        Optionally prefixes tokens with ``[POS]`` markers and wraps the first
        and last token of every span in ``[ENT]`` / ``[/ENT]``. The caller's
        token lists are left untouched.

        NOTE(review): when ``pos`` is shorter than a sentence (as happens for
        the span copies appended in ``forward``) the ``zip`` below truncates
        that sentence to the POS length — confirm this is intended.
        """
        if pos and self.config.pos:
            sentences = [[f'[{tag}] {tok}' for tok, tag in zip(sent, sent_pos)]
                         for sent, sent_pos in zip(sentences, pos)]
        if ent_bounds and self.config.use_ent_markers:
            # work on copies so the markers never leak into the caller's data
            sentences = [list(sent) for sent in sentences]
            for sent, sent_ents in zip(sentences, ent_bounds):
                for ent in sent_ents:
                    sent[ent[0]] = f'[ENT] {sent[ent[0]]}'
                    sent[ent[1]] = f'{sent[ent[1]]} [/ENT]'

        outs = self.transformer(sentences, is_pretokenized=True,
                                token_pooling=self.config.token_pooling)
        return outs.pooled_tokens

    def forward(self, sentences, pos=None, tags=None, **kwargs):
        """Classify the spans of a batch and, with ``tags``, compute the loss.

        :param sentences: list of token lists
        :param pos: optional list of POS-tag lists
        :param tags: optional gold tag sequences; enables the ``'loss'`` output
        :param kwargs: ``pred_tags`` supplies span tags instead of running
            ``span_net``; ``output_word_vecs`` / ``output_ent_scores`` add the
            token vectors / raw entity scores to the result
        :return: dict with ``'bounds'`` and either ``'pred_tags'`` (no
            ``tags``) or ``'loss'`` (``None`` when the batch has no spans)
        """
        out_dict = {}
        span_out = None
        pred_span_seqs = kwargs.get('pred_tags', None)
        if pred_span_seqs is None:
            span_out = self.span_net(sentences, pos=pos,
                                     output_word_vecs=self.config.use_span_emb,
                                     tags=tags if self.config.tune_span_net else None)
            pred_span_seqs = span_out['pred_tags']
        bounds = [[e[1] for e in extract_spans(t, tagless=True)[3]]
                  for t in pred_span_seqs]

        targets = None
        if tags is not None:
            gold_spans = [list(extract_spans(t, tagless=True)[3]) for t in tags]
            # for every predicted span collect the gold labels with equal bounds
            matches = [[[g[0] for g in golds
                         if p[0] == g[1][0] and p[1] == g[1][1]]
                        for p in preds]
                       for preds, golds in zip(bounds, gold_spans)]
            # exactly one match -> that label, anything else -> 'O'
            targets = [[span_matches[0] if len(span_matches) == 1 else 'O'
                        for span_matches in sent_matches]
                       for sent_matches in matches]

        # append "<sep> span tokens" for every span, then one closing separator
        sep_token = self.transformer.tokenizer.sep_token
        sentences = [sent
                     + [t for bd in sent_bounds
                        for t in [sep_token] + sent[bd[0]:bd[1] + 1]]
                     + [sep_token]
                     for sent, sent_bounds in zip(sentences, bounds)]
        sep_ids = [[i for i, s in enumerate(sent) if s == sep_token]
                   for sent in sentences]
        embs = self._extract_sentence_vectors(sentences, pos, bounds)
        if kwargs.get('output_word_vecs', False):
            out_dict['word_vecs'] = embs

        span_vecs = [
            torch.stack([torch.cat((torch.sum(e[b[0]:b[1] + 1], dim=0),
                                    torch.sum(e[spi[i]:spi[i + 1] + 1], dim=0)))
                         for i, b in enumerate(bd)])
            if bd else torch.zeros((0)).to(self.device)
            for e, bd, spi in zip(embs, bounds, sep_ids)]
        ent_scores = [self.ent_output_layer(sv) if len(sv) else sv
                      for sv in span_vecs]
        if kwargs.get('output_ent_scores', False):
            out_dict['ent_scores'] = ent_scores
        out_dict['bounds'] = bounds

        if tags is None:
            # `sentences` is the extended version here, as it always was
            out_dict['pred_tags'] = self.from_ent_scores(ent_scores, sentences, bounds)
            return out_dict

        ent_targs = torch.tensor([self.ent_tags.index(t)
                                  for targ in targets for t in targ],
                                 dtype=torch.long).to(self.device)
        ent_preds = torch.cat(ent_scores)
        if not len(ent_preds):
            out_dict['loss'] = None
            return out_dict
        ent_loss = cross_entropy(ent_preds, ent_targs)
        # span_out is None when the caller supplied pred_tags; there is no
        # span loss to add then (this used to raise a NameError)
        if self.config.tune_span_net and span_out is not None:
            ent_loss = ent_loss + span_out['loss']
        out_dict['loss'] = ent_loss
        return out_dict

    def from_ent_scores(self, ent_scores, sentences, bounds):
        """Decode entity scores into ``B-x`` / ``I-x`` / ``O`` tag sequences.

        One sequence per sentence, as long as that sentence; spans whose
        arg-max label is ``'O'`` stay untagged.
        """
        max_tags = [[self.ent_tags[torch.argmax(e)] for e in es]
                    for es in ent_scores]
        combined_sequences = []
        for sent_tags, sent_bounds, sent in zip(max_tags, bounds, sentences):
            seq = ['O'] * len(sent)
            for tag, bound in zip(sent_tags, sent_bounds):
                seq[bound[0]] = 'O' if tag == 'O' else f'B-{tag}'
                for i in range(bound[0] + 1, bound[1] + 1):
                    seq[i] = 'O' if tag == 'O' else f'I-{tag}'
            combined_sequences.append(seq)
        return combined_sequences
utils.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import pyarabic.araby as araby
3
+ # import stanza
4
+ import numpy as np
5
+
6
class Config:
    """Plain attribute bag; settings are attached to instances after construction."""

    def __init__(self):
        super().__init__()
9
+
10
+
11
def read_conll_ner(path):
    """Read a whitespace-separated CoNLL-style file.

    Blank lines separate sentences; lines starting with ``#`` are skipped
    when they appear before the first token of a sentence.

    :param path: path of the file (read as UTF-8)
    :return: a list whose first element is the list of sentences (each a list
        of column lists) followed by one list per non-token column holding
        that column's distinct values in order of first appearance
    """
    sentences = []
    unique_entries = []
    curr_sentence = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                if curr_sentence:
                    sentences.append(curr_sentence)
                    curr_sentence = []
                continue
            if line.startswith('#') and not curr_sentence:
                continue
            entry = line.split()
            curr_sentence.append(entry)
            if not unique_entries:
                unique_entries = [[] for _ in entry[1:]]
            for value, seen in zip(entry[1:], unique_entries):
                if value not in seen:
                    seen.append(value)
    # a file without a trailing blank line still ends with a full sentence
    if curr_sentence:
        sentences.append(curr_sentence)
    return [sentences] + unique_entries
33
+
34
+
35
def read_pickled_conll(path):
    """Unpickle and return the object stored in the file at ``path``.

    NOTE(review): unpickling runs arbitrary code embedded in the file, so
    this must only be used on trusted data.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)
39
+
40
+
41
def split_conll_docs(conll_sents, skip_docstart=True):
    """Group sentences into documents at ``-DOCSTART-`` marker sentences.

    A sentence whose first token is ``-DOCSTART-`` closes the document
    collected so far (if it has any sentences). The marker sentence itself is
    dropped unless ``skip_docstart`` is false. The document still open at the
    end is always appended, even when it is empty.
    """
    docs = []
    current = []
    for sent in conll_sents:
        is_marker = sent[0][0] == '-DOCSTART-'
        if is_marker and current:
            docs.append(current)
            current = []
        if is_marker and skip_docstart:
            continue
        current.append(sent)
    docs.append(current)
    return docs
54
+
55
+
56
def _interleave_sep(items, sep_token):
    """Return ``items`` with ``sep_token`` placed between consecutive entries."""
    return f"<split>{sep_token}<split>".join(items).split("<split>")


def create_context_data(docs, pos_col_id=1, tag_col_id=3, context_length=1, **kwargs):
    """Attach retrieval or neighbouring-sentence context to the input data.

    The behaviour is selected with the ``ctx_type`` keyword:

    * ``"cand_titles"`` / ``"cand_links"`` / ``"raw_text"`` / ``"infobox"``:
      ``docs`` is a list of dicts; a ``"ctx_sent"`` entry is added to each
      dict in place and ``docs`` is returned.
    * ``"matched_spans"`` / ``"bm25_matched_spans"``: ``docs`` is a list of
      sentences and ``matched_spans`` (keyword) supplies spans per sentence;
      a list of ``[tokens, None, tags, [sentence length]]`` is returned.
    * anything else (default ``"other"``): ``docs`` is a list of documents
      (lists of sentences); each sentence is joined with the following
      ``context_length - 1`` sentences, wrapping around to the start of the
      document, and a list of 5-tuples
      ``(tokens, pos or None, tags, sentence lengths, {})`` is returned.

    Other keywords: ``sep_token`` (default ``"[SEP]"``), ``pickled_data``
    (``"bm25_matched_spans"``), ``infobox_keys_path`` or an already loaded
    ``infobox_keys`` mapping (``"infobox"``).
    """
    ctx_type = kwargs.get("ctx_type", "other")
    sep_token = kwargs.get("sep_token", "[SEP]")

    if ctx_type == "cand_titles":
        for doc in docs:
            titles = _interleave_sep(
                [cand["doc_title"] for cand in doc["BM25_cands"]], sep_token)
            doc["ctx_sent"] = doc["query"] + [sep_token] + titles
        return docs

    if ctx_type == "cand_links":
        for doc in docs:
            titles = _interleave_sep(
                [cand["doc_title"] for cand in doc["BM25_cands"]], sep_token)
            linked = _interleave_sep(
                [title for cand in doc["BM25_cands"]
                 for title in cand["linked_titles"]], sep_token)
            doc["ctx_sent"] = doc["query"] + [sep_token] + titles + [sep_token] + linked
        return docs

    if ctx_type == "raw_text":
        # one context sentence per candidate
        for doc in docs:
            doc["ctx_sent"] = [doc["query"] + [sep_token] + [cand["processed_text"]]
                               for cand in doc["BM25_cands"]]
        return docs

    if ctx_type == 'matched_spans':
        matched_spans = kwargs.get('matched_spans')
        return [
            [[row[0] for row in sent]
             + [tok for span in spans for tok in [sep_token] + span[1]],
             None,  # pos tags
             [row[tag_col_id] for row in sent] if tag_col_id > 0 else None,
             [len(sent)]]
            for sent, spans in zip(docs, matched_spans)]

    if ctx_type == 'bm25_matched_spans':
        matched_spans = kwargs.get('matched_spans')
        pickled_data = kwargs.get('pickled_data')
        docs = [
            [[row[0] for row in sent]
             + [tok for span in spans for tok in [sep_token] + span[1]],
             None,  # pos tags
             [row[tag_col_id] for row in sent],
             [len(sent)]]
            for sent, spans in zip(docs, matched_spans)]
        for sample, doc in zip(docs, pickled_data):
            titles = _interleave_sep(
                [cand["doc_title"] for cand in doc["BM25_cands"]], sep_token)
            linked = _interleave_sep(
                [title for cand in doc["BM25_cands"]
                 for title in cand["linked_titles"]], sep_token)
            sample[0] = sample[0] + [sep_token] + titles + [sep_token] + linked
        return docs

    if ctx_type == "infobox":
        # an already loaded mapping may be passed to avoid re-reading the pickle
        infobox_keys = kwargs.get("infobox_keys")
        if infobox_keys is None:
            infobox_keys = read_pickled_conll(kwargs.get("infobox_keys_path"))
        if docs and 'pred_spans' in docs[0]:
            docs = get_pred_ent_bounds(docs)
        for doc in docs:
            bounds_key = 'pred_ent_bounds' if 'pred_spans' in doc else 'ent_bounds'
            query = doc['query']
            spaced = [' '.join(query[bd[0]:bd[1] + 1]) for bd in doc[bounds_key]]
            joined = [''.join(query[bd[0]:bd[1] + 1]) for bd in doc[bounds_key]]
            # de-duplicate while keeping a deterministic order
            ents = list(dict.fromkeys(spaced + joined))
            infobox = []
            for en in ents:
                if en in infobox_keys and infobox_keys[en]:
                    # copy: writing the markers into the shared entry made
                    # them pile up every time the entity was seen again
                    wrapped = list(infobox_keys[en])
                    wrapped[0] = '[INFO] ' + wrapped[0]
                    wrapped[-1] = wrapped[-1] + ' [/INFO]'
                    infobox.extend(wrapped)
            doc["ctx_sent"] = query + [sep_token] + infobox
        return docs

    # neighbouring-sentence context
    res = []
    for doc in docs:
        ctx_len = context_length if context_length > 0 else len(doc)
        # for the last sentences loop around to the beginning for context
        padded_doc = doc + doc[:ctx_len]
        for i in range(len(doc)):
            window = padded_doc[i:i + ctx_len]
            res.append((
                [row[0] for sent in window for row in sent],
                [row[pos_col_id] for sent in window for row in sent] if pos_col_id > 0 else None,
                [row[tag_col_id] for sent in window for row in sent],
                [len(sent) for sent in window],
                {}  # dictionary for extra context
            ))
    return res
133
+
134
+
135
# Count gold chunks, predicted chunks and correctly predicted gold chunks for one sentence.
# Every item of `sentence` is unpacked into five fields; only the last two are used:
# the gold tag (`gt`) and the predicted tag (`pt`).
# Returns a dict with the keys 'correct', 'correct_tagless', 'gold_count' and 'pred_count'.
# NOTE(review): this listing has lost its indentation, so the exact nesting of the
# branches below cannot be confirmed from here.
+ def calc_correct(sentence):
136
+ gold_chunks = []
137
+ parallel_chunks = []
138
+ pred_chunks = []
139
+ curr_gold_chunk = []
140
+ curr_parallel_chunk = []
141
+ curr_pred_chunk = []
142
+ prev_tag = None
143
+ for line in sentence:
144
+ _, _, _, gt, pt = line
145
+ curr_tag = None
# type part of the predicted tag (text after the first '-'); stays None when there is no '-'
146
+ if '-' in pt:
147
+ curr_tag = pt.split('-')[1]
# gold side: a gold 'B' opens a new gold chunk; `parallel_chunks` collects the predicted
# tags seen at the same positions
148
+ if gt.startswith('B'):
149
+ if curr_gold_chunk:
150
+ gold_chunks.append(curr_gold_chunk)
151
+ parallel_chunks.append(curr_parallel_chunk)
152
+ curr_gold_chunk = [gt]
153
+ curr_parallel_chunk = [pt]
154
+ elif gt.startswith('I') or (pt.startswith('I') and curr_tag == prev_tag
155
+ and curr_gold_chunk):
156
+ curr_gold_chunk.append(gt)
157
+ curr_parallel_chunk.append(pt)
158
+ elif gt.startswith('O') and pt.startswith('O'):
159
+ if curr_gold_chunk:
160
+ gold_chunks.append(curr_gold_chunk)
161
+ parallel_chunks.append(curr_parallel_chunk)
162
+ curr_gold_chunk = []
163
+ curr_parallel_chunk = []
# prediction side: chunks are built from the predicted tags alone
164
+ if pt.startswith('O'):
165
+ if curr_pred_chunk:
166
+ pred_chunks.append(curr_pred_chunk)
167
+ curr_pred_chunk = []
168
+ elif pt.startswith('B'):
169
+ if curr_pred_chunk:
170
+ pred_chunks.append(curr_pred_chunk)
171
+ curr_pred_chunk = [pt]
172
+ prev_tag = curr_tag
173
+ else:
174
+ if prev_tag is not None and curr_tag != prev_tag:
175
+ prev_tag = curr_tag
176
+ if curr_pred_chunk:
177
+ pred_chunks.append(curr_pred_chunk)
178
+ curr_pred_chunk = []
179
+ curr_pred_chunk.append(pt)
180
+
# flush chunks that are still open at the end of the sentence
181
+ if curr_gold_chunk:
182
+ gold_chunks.append(curr_gold_chunk)
183
+ parallel_chunks.append(curr_parallel_chunk)
184
+ if curr_pred_chunk:
185
+ pred_chunks.append(curr_pred_chunk)
# a gold chunk counts as correct when every aligned (gold, predicted) tag pair is equal
186
+ correct = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
187
+ if not len([1 for g, p in zip(gc, pc) if g != p])])
# same check, but comparing only the first character of each tag
188
+ correct_tagless = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
189
+ if not len([1 for g, p in zip(gc, pc) if g[0] != p[0]])])
190
+ # return correct, gold_chunks, parallel_chunks, pred_chunks, ob1_correct, correct_tagless
191
+ return {'correct': correct,
192
+ 'correct_tagless': correct_tagless,
193
+ 'gold_count': len(gold_chunks),
194
+ 'pred_count': len(pred_chunks)}
195
+
196
+
197
def tag_sentences(sentences):
    """POS-tag raw sentences with a stanza English pipeline.

    :param sentences: iterable of sentence strings
    :return: one list per input sentence holding ``[word text, UPOS tag]``
        pairs for every word stanza finds in it
    :raises ImportError: when the optional ``stanza`` package is missing
    """
    # The module-level import of stanza is commented out, so the name has to
    # be brought into scope here; importing lazily keeps the dependency optional.
    import stanza

    nlp = stanza.Pipeline(lang='en', processors='tokenize,pos', logging_level='WARNING')
    tagged_sents = []
    for sentence in sentences:
        parsed = nlp(sentence)
        tagged_sents.append([[w.text, w.upos]
                             for s in parsed.sentences
                             for w in s.words])
    return tagged_sents
208
+
209
+
210
# Scan a tag sequence and collect the spans it encodes.
# Tags starting with 'B' or 'I' build spans; tags starting with 'O' or '-' close the open span.
# With tagless=False the span type is read as the text after the first '-' of a tag;
# with tagless=True no type is parsed.
# Returns the 4-tuple (spans_positions, span_bounds, all_bounds, tagged_bounds):
#   spans_positions: [tokens of the span, len(all_bounds) when the span was closed]
#   span_bounds:     [start, end] token indices, end being the last token of the span
#   all_bounds:      [[start, end], 'E', index] per span and [[i], 'W', index] entries
#   tagged_bounds:   [label, [start, end]] where label is the text after '-' in the span's
#                    first token, or that token itself when it has no '-'
# NOTE(review): this listing has lost its indentation, so the nesting of the branches
# below cannot be confirmed from here. `span_tags` is filled but never returned.
+ def extract_spans(sentence, tagless=False):
211
+ spans_positions = []
212
+ span_bounds = []
213
+ all_bounds = []
214
+ span_tags = []
215
+ curr_tag = None
216
+ curr_span = []
217
+ curr_span_start = -1
218
+ # span ids, span types
219
+ for i, token in enumerate(sentence):
# a 'B' tag closes the span that is still open and starts a new one at i
220
+ if token.startswith('B'):
221
+ if curr_span:
222
+ spans_positions.append([curr_span, len(all_bounds)])
223
+ span_bounds.append([curr_span_start, i-1])
224
+ all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
225
+ if not tagless:
226
+ span_tags.append(token.split('-')[1])
227
+ curr_span = []
228
+ curr_tag = None
229
+ curr_span.append(token)
230
+ curr_tag = None if tagless else token.split('-')[1]
231
+ curr_span_start = i
# an 'I' tag extends the open span; with types enabled, a type different from the
# current one closes the open span and starts a new one at i
232
+ elif token.startswith('I'):
233
+ if not tagless:
234
+ tag = token.split('-')[1]
235
+ if tag != curr_tag and curr_tag is not None:
236
+ spans_positions.append([curr_span, len(all_bounds)])
237
+ span_bounds.append([curr_span_start, i - 1])
238
+ span_tags.append(token.split('-')[1])
239
+ all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
240
+ curr_span = []
241
+ curr_tag = tag
242
+ curr_span_start = i
243
+ elif curr_tag is None:
244
+ curr_span = []
245
+ curr_tag = tag
246
+ curr_span_start = i
247
+ elif not curr_span:
248
+ curr_span_start = i
249
+ curr_span.append(token)
# 'O' and '-' tags close the open span
250
+ elif token.startswith('O') or token.startswith('-'):
251
+ if curr_span:
252
+ spans_positions.append([curr_span, len(all_bounds)])
253
+ span_bounds.append([curr_span_start, i-1])
254
+ all_bounds.append([[curr_span_start, i-1], 'E', len(all_bounds)])
255
+ curr_span = []
256
+ curr_tag = None
257
+ all_bounds.append([[i], 'W', len(all_bounds)])
258
+ # check if sentence ended with a span
259
+ if curr_span:
260
+ spans_positions.append([curr_span, len(all_bounds)])
261
+ span_bounds.append([curr_span_start, len(sentence) - 1])
262
+ all_bounds.append([[curr_span_start, len(sentence) - 1], 'E', len(all_bounds)])
# label every span with the type part of its first token (or the token itself)
263
+ tagged_bounds = [[loc[0][0].split('-')[1] if '-' in loc[0][0] else loc[0][0], bound]
264
+ for loc, bound in zip(spans_positions, span_bounds)]
265
+ return spans_positions, span_bounds, all_bounds, tagged_bounds
266
+
267
+
268
def ner_corpus_stats(corpus_path):
    """Compute span-length, tag and POS-pattern statistics of a CoNLL corpus.

    The file is read with ``read_conll_ner``; column 1 is used as the POS tag
    and column 3 as the NER tag of each token.

    :param corpus_path: path of the corpus file
    :return: the 5-tuple ``(len_stats, tag_len_stats, tag_counts, pos_stats,
        tag_pos_stats)`` where

        * ``len_stats[i]`` is the share of spans with ``i + 1`` tokens,
        * ``tag_len_stats[tag][i]`` is the same within one tag,
        * ``tag_counts`` is a list of ``[tag, share of all spans]`` sorted by
          ascending count,
        * ``pos_stats`` is a list of ``[POS pattern, share]`` sorted by
          descending count, a pattern being the span's POS tags joined by
          ``_``,
        * ``tag_pos_stats[tag]`` is a list of ``[POS pattern, share within
          tag]`` sorted by descending share.
    """
    cols = read_conll_ner(corpus_path)
    sentences = cols[0]
    tags = list({t.split('-')[1] for t in cols[3] if '-' in t})

    # every span is a [tag, [start, end]] pair as returned by extract_spans
    sent_spans = [extract_spans([row[3] for row in sent])[3] for sent in sentences]
    flat_spans = [span for spans in sent_spans for span in spans]
    span_lens = [bound[1] - bound[0] + 1 for _, bound in flat_spans]
    total = len(span_lens)

    len_stats = [span_lens.count(n + 1) / total
                 for n in range(max(span_lens, default=0))]

    tag_lens_dict = {k: [] for k in tags}
    tag_counts_dict = {k: 0 for k in tags}
    for span_tag, bound in flat_spans:
        # the tag is already bare here; the old `span[0][0].split('-')[1]`
        # took its first character and raised IndexError
        tag_lens_dict[span_tag].append(bound[1] - bound[0] + 1)
        tag_counts_dict[span_tag] += 1

    tag_counts = [[tag, count / total if total else 0.0]
                  for tag, count in sorted(tag_counts_dict.items(),
                                           key=lambda kv: kv[1])]

    tag_len_stats = {k: [v.count(n + 1) / len(v) for n in range(max(v, default=0))]
                     for k, v in tag_lens_dict.items()}

    span_rows = [sent[bound[0]:bound[1] + 1]
                 for sent, spans in zip(sentences, sent_spans)
                 for _, bound in spans]
    span_pos = [[rows[0][-1].split('-')[1], '_'.join(r[1] for r in rows)]
                for rows in span_rows]
    unique_pos = list({pattern for _, pattern in span_pos})
    pos_dict = {k: 0 for k in unique_pos}
    for _, pattern in span_pos:
        pos_dict[pattern] += 1
    unique_pos.sort(key=lambda p: pos_dict[p], reverse=True)
    pos_stats = [[p, pos_dict[p] / len(span_pos)] for p in unique_pos]

    tag_pos_dict = {kt: {kp: 0 for kp in unique_pos} for kt in tags}
    for tag, pattern in span_pos:
        tag_pos_dict[tag][pattern] += 1
    tag_pos_stats = {
        kt: [[p, tag_pos_dict[kt][p] / tag_counts_dict[kt] if tag_counts_dict[kt] else 0.0]
             for p in unique_pos]
        for kt in tags}
    for kt in tags:
        tag_pos_stats[kt].sort(key=lambda item: item[1], reverse=True)

    return len_stats, tag_len_stats, tag_counts, pos_stats, tag_pos_stats
316
+
317
+
318
def filter_by_max_ents(sentences, max_ent_length):
    """
    Keep only the sentences whose named entities all span at most
    max_ent_length tokens; sentences without entities are kept as well.

    :param sentences: sentences in conll format as extracted by read_conll_ner
    :param max_ent_length: The maximum number of tokens in an entity
    :return: a list of sentences
    """
    kept = []
    for sent in sentences:
        bounds = extract_spans([row[3] for row in sent])[1]
        if all(end - start + 1 <= max_ent_length for start, end in bounds):
            kept.append(sent)
    return kept
334
+
335
+
336
def get_pred_ent_bounds(docs):
    """Derive ``[start, end]`` entity bounds from each doc's ``'pred_spans'`` tags.

    A ``'B'`` opens a new one-token span; an ``'I'`` moves the end of the most
    recently opened span to the current position. The bounds are stored under
    ``'pred_ent_bounds'`` on every doc and the same ``docs`` object is returned.
    """
    for doc in docs:
        bounds = []
        for position, label in enumerate(doc['pred_spans']):
            if label == 'B':
                bounds.append([position, position])
            elif label == 'I' and bounds:
                bounds[-1][1] = position
        doc['pred_ent_bounds'] = bounds
    return docs
348
+
349
def enumerate_spans(batch):
    """List every ``[start, end]`` pair with ``start <= end`` for each sequence in ``batch``.

    Pairs are ordered by start index, then by end index; both are inclusive
    positions within the sequence.
    """
    return [[[start, end]
             for start in range(len(sequence))
             for end in range(start, len(sequence))]
            for sequence in batch]
363
+
364
def compact_span_enumeration(batch):
    """Return, per sequence in ``batch``, all ``[x, y]`` index pairs.

    For a sequence of length ``n`` this yields ``n * n`` pairs, with ``y``
    varying slowest.

    NOTE(review): unlike ``enumerate_spans`` this also emits pairs with
    ``x > y`` — confirm whether that is intended.
    """
    result = []
    for sequence in batch:
        size = len(sequence)
        pairs = []
        for y in range(size):
            for x in range(size):
                pairs.append([x, y])
        result.append(pairs)
    return result
371
+
372
def preprocess_data(data):
    """Strip tashkeel and then tatweel from the tokens of every sample.

    The tokens are the first element of each sample; all remaining elements
    are carried over unchanged into a new list.
    """
    cleaned = []
    for sample in data:
        tokens = [araby.strip_tatweel(araby.strip_tashkeel(token))
                  for token in sample[0]]
        cleaned.append([tokens, *sample[1:]])
    return cleaned
381
+
382
+
383
def generate_targets(enumerated_spans, sentences):
    """Build 0/1 labels marking which enumerated spans are gold spans.

    :param enumerated_spans: per sentence, the list of candidate ``[start, end]`` spans
    :param sentences: per sentence, the gold tag sequence (spans are read
        with ``extract_spans`` in tagless mode)
    :return: per sentence, a list as long as its candidates with ``1`` at the
        position of every gold span and ``0`` elsewhere
    :raises ValueError: when a gold span is not among the candidates
    """
    gold_bounds = [[tagged[1] for tagged in extract_spans(sentence, True)[3]]
                   for sentence in sentences]

    targets = []
    for candidates, golds in zip(enumerated_spans, gold_bounds):
        # resolve every gold span first so a missing one fails before labelling
        gold_positions = [candidates.index(gold) for gold in golds]
        labels = [0] * len(candidates)
        for position in gold_positions:
            labels[position] = 1
        targets.append(labels)
    return targets
412
+
413
def label_tags(tags):
    """Map every tag to ``0`` when it equals ``"O"`` and to ``1`` otherwise."""
    return [0 if tag == "O" else 1 for tag in tags]
validate.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import re
3
+
4
+ import torch
5
+ from tqdm.auto import tqdm
6
+
7
+ from .network import EntNet
8
+ from .utils import read_conll_ner, split_conll_docs, create_context_data, extract_spans
9
+
10
# Run on the GPU whenever torch reports CUDA as available, else on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
12
+
13
+
14
def classify(model, sents, pos, batch_size):
    """Run ``model`` over ``sents`` in batches and pair tokens with tags.

    The model is switched to evaluation mode and called once per batch as
    ``model(sentences=..., pos=...)``; the ``'pred_tags'`` entry of each
    result is collected.

    Args:
        model: Callable model exposing ``eval()``; each call must return a
            mapping with a ``'pred_tags'`` entry holding one tag sequence
            per sentence in the batch.
        sents: List of token sequences.
        pos: Per-sentence inputs aligned with ``sents``, sliced in step
            with it.
        batch_size: Number of sentences per forward pass.

    Returns:
        One list per sentence of ``[token, predicted_tag]`` pairs.
    """
    model.eval()
    predicted = []
    # Inference only: disabling gradient tracking avoids building autograd
    # graphs for every batch.
    with torch.no_grad():
        for start in tqdm(range(0, len(sents), batch_size), desc='classifying... '):
            stop = start + batch_size
            output = model(sentences=sents[start:stop], pos=pos[start:stop])
            predicted.extend(output['pred_tags'])
    return [[[word, tag] for word, tag in zip(sent, tags)]
            for sent, tags in zip(sents, predicted)]
23
+
24
+
25
def entities_from_token_classes(tokens):
    """Group a sequence of B/I-style tags into entity spans.

    A tag starting with ``B`` opens a new entity (closing any open one).
    While an entity is open, a tag starting with ``I`` extends it and any
    other tag closes it. ``I`` tags seen while no entity is open are
    ignored. The entity type is the text between the first and second
    ``-`` of the opening tag, or ``''`` when that tag contains no ``-``.

    Args:
        tokens: Sequence of tag strings, one per token.

    Returns:
        A list of dicts ``{"type": str, "index": [start, end]}`` in order of
        appearance; ``end`` is inclusive.
    """
    begin_pattern = r"^B"
    inside_pattern = r"^I"

    found = []
    label = None  # type of the currently open entity; None when none is open
    first = last = 0

    def close_open_entity():
        # Reads the enclosing variables at call time, i.e. the open entity.
        found.append({
            "type": label,
            "index": [first, last]
        })

    for position, tag in enumerate(tokens):
        if re.match(begin_pattern, tag) is not None:
            if label is not None:
                close_open_entity()
            label = tag.split('-')[1] if '-' in tag else ''
            first = last = position
        elif label is not None:
            if re.match(inside_pattern, tag) is None:
                # First tag after the end of the open entity.
                close_open_entity()
                label = None
            else:
                last = position

    # Flush an entity that runs up to the end of the sequence.
    if label is not None:
        close_open_entity()

    return found
71
+
72
+
73
def _precision_recall_f1(true_pos, n_preds, n_targs):
    """Return (precision, recall, F1), using 0 wherever a denominator is 0."""
    precision = true_pos / n_preds if n_preds else 0
    recall = true_pos / n_targs if n_targs else 0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0
    return precision, recall, f1


def calc_f1(targs, preds):
    """Compute span-level precision, recall and F1 for predicted entities.

    A prediction matches a target when both ``index`` bounds are equal
    (unlabelled match); it is a labelled match when the ``type`` values are
    equal too.

    Args:
        targs: Per-sentence lists of gold entities, each a dict with
            ``'type'`` and ``'index'`` (``[start, end]``).
        preds: Per-sentence lists of predicted entities in the same format,
            aligned with ``targs``.

    Returns:
        A dict keyed by ``'overall'`` and by every entity type seen.
        ``'overall'`` holds the counts ``unl_tp``, ``lab_tp``, ``targs``,
        ``preds`` plus ``unl_p/unl_r/unl_f1``, ``lab_p/lab_r/lab_f1``
        (micro) and ``macro_lab_f1`` (mean of the per-type labelled F1, or
        0 when no entity type was seen). Each type entry holds ``lab_tp``,
        ``targs``, ``preds`` and ``lab_p/lab_r/lab_f1``.
    """
    stat_dict = {
        'overall': {'unl_tp': 0, 'lab_tp': 0, 'targs': 0, 'preds': 0}
    }
    overall = stat_dict['overall']

    def class_stats(entity_type):
        # Create the per-type counters on first sight of a type.
        return stat_dict.setdefault(
            entity_type, {'lab_tp': 0, 'targs': 0, 'preds': 0})

    for sent_targs, sent_preds in zip(targs, preds):
        overall['targs'] += len(sent_targs)
        overall['preds'] += len(sent_preds)

        for pred in sent_preds:
            class_stats(pred['type'])['preds'] += 1

        for targ in sent_targs:
            targ_stats = class_stats(targ['type'])
            targ_stats['targs'] += 1
            # Is there a predicted span that matches exactly?
            for pred in sent_preds:
                if pred['index'][0] == targ['index'][0] and pred['index'][1] == targ['index'][1]:
                    overall['unl_tp'] += 1
                    # If so, do the types match as well?
                    if pred['type'] == targ['type']:
                        overall['lab_tp'] += 1
                        targ_stats['lab_tp'] += 1

    for key, stats in stat_dict.items():
        if key == 'overall':
            stats['unl_p'], stats['unl_r'], stats['unl_f1'] = _precision_recall_f1(
                stats['unl_tp'], stats['preds'], stats['targs'])
        stats['lab_p'], stats['lab_r'], stats['lab_f1'] = _precision_recall_f1(
            stats['lab_tp'], stats['preds'], stats['targs'])

    class_f1s = [v['lab_f1'] for k, v in stat_dict.items() if k != 'overall']
    # With no entities at all there are no classes to average over; report 0
    # instead of dividing by zero.
    overall['macro_lab_f1'] = sum(class_f1s) / len(class_f1s) if class_f1s else 0
    return stat_dict
114
+
115
+
116
def main(args):
    """Evaluate saved model(s) on a test file and print entity-level scores.

    Reads the test set from ``args.test_path``, builds context windows with
    ``args.context_size``, loads either a single ``EntNet`` or a
    ``StagedEnsemble``, classifies every sentence in batches of
    ``args.batch_size`` and prints unlabelled, labelled (micro and macro)
    and per-type precision/recall/F1.

    Args:
        args: Parsed command-line namespace with ``test_path``,
            ``context_size``, ``model_path`` (list), ``span_model_path``
            (list or None) and ``batch_size``.
    """
    global device
    device = torch.device('cuda' if use_cuda else 'cpu')

    test_columns = read_conll_ner(args.test_path)
    test_docs = split_conll_docs(test_columns[0])
    test_data = create_context_data(test_docs, args.context_size)

    sents = [td[0] for td in test_data]
    pos = [td[1] for td in test_data]

    if len(args.model_path) > 1 or args.span_model_path is not None:
        # NOTE(review): StagedEnsemble is neither imported nor defined in
        # this module, so this branch raises NameError as written -- add
        # the import once its home module is confirmed.
        model = StagedEnsemble(model_paths=args.model_path, span_model_paths=args.span_model_path, device=device)
    else:
        model = EntNet.load_model(args.model_path[0], device=device)
    model.to(device)

    res = classify(model, sents, pos, args.batch_size)

    # Gold tag sequences: element 2 of each context sample.
    targets = [td[2] for td in test_data]
    targ_tags = [entities_from_token_classes(tags) for tags in targets]
    pred_tags = [entities_from_token_classes([t[1] for t in r]) for r in res]
    result = calc_f1(targ_tags, pred_tags)

    print(f'Overall unlabelled - F1:{result["overall"]["unl_f1"]}, '
          f'P:{result["overall"]["unl_p"]}, '
          f'R:{result["overall"]["unl_r"]}')
    print(f'Overall labelled - Micro F1:{result["overall"]["lab_f1"]}, '
          f'P:{result["overall"]["lab_p"]}, '
          f'R:{result["overall"]["lab_r"]}')
    print(f'Overall labelled - Macro F1:{result["overall"]["macro_lab_f1"]}')
    for k, v in result.items():
        if k == 'overall':
            continue
        print(f'{k} - F1:{v["lab_f1"]}, P:{v["lab_p"]}, R:{v["lab_r"]}')
152
+
153
+
154
if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, nargs='+', required=True,
                        help='Path(s) to the saved model(s); giving more than '
                             'one path selects the staged ensemble')
    parser.add_argument('--span_model_path', type=str, nargs='*', default=None,
                        help='Path(s) to span model(s); supplying this option '
                             'selects the staged ensemble')
    parser.add_argument('--test_path', type=str, default=None,
                        help='Path to the test file read by read_conll_ner')
    parser.add_argument('--context_size', type=int, default=1,
                        help='Context size passed to create_context_data')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Number of sentences classified per batch')

    args = parser.parse_args()
    main(args)