ajit committed on
Commit
5775680
1 Parent(s): 3df8af2

Initial creation

Browse files
BatchInference.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import subprocess
3
+ #from pytorch_transformers import *
4
+ from transformers import *
5
+ import pdb
6
+ import operator
7
+ from collections import OrderedDict
8
+ import numpy as np
9
+ import argparse
10
+ import sys
11
+ import traceback
12
+ import string
13
+ import common as utils
14
+ import config_utils as cf
15
+ import requests
16
+ import json
17
+ import streamlit as st
18
+
19
+ # OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
20
+ import logging
21
+ logging.basicConfig(level=logging.INFO)
22
+
23
+
24
# Module-level defaults.
DEFAULT_TOP_K = 20
DEFAULT_CONFIG = "./server_config.json"
DEFAULT_MODEL_PATH = './'
DEFAULT_LABELS_PATH = './labels.txt'
DEFAULT_TO_LOWER = False
DESC_FILE = "./common_descs.txt"
# Suffix a user attaches to a word to request tagging of that word.
SPECIFIC_TAG = ":__entity__"
# Sentences are truncated until their tokenized form is shorter than this;
# the value leaves additional buffer for CLS, SEP and the entity term.
MAX_TOKENIZED_SENT_LENGTH = 500

try:
    from subprocess import DEVNULL  # Python 3.
except ImportError:
    # Fallback for interpreters whose subprocess module lacks DEVNULL.
    # `os` is imported here because the file does not import it at the top.
    import os
    DEVNULL = open(os.devnull, 'wb')
37
+
38
+
39
@st.cache()
def load_bert_model(model_name,to_lower):
    """Load the BERT tokenizer and masked-LM model stored at ``model_name``.

    Args:
        model_name: Name or path handed to ``from_pretrained``.
        to_lower: Forwarded as ``do_lower_case`` to the tokenizer.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Exception: Whatever ``from_pretrained`` raised. The failure is logged
            and re-raised instead of being swallowed; previously the function
            returned ``None`` and the caller failed later while unpacking it.
    """
    try:
        bert_tokenizer = BertTokenizer.from_pretrained(model_name,do_lower_case=to_lower)
        bert_model = BertForMaskedLM.from_pretrained(model_name)
    except Exception:
        logging.exception("Failed to load BERT model/tokenizer from %s", model_name)
        raise
    return bert_tokenizer,bert_model
47
+
48
def read_descs(file_name):
    """Read a descriptor file (one entry per line) into a lookup dict.

    Every non-empty line becomes a key mapped to ``1``. Only the trailing
    newline is stripped from each line. Blank lines are skipped; the previous
    ``while line:`` loop stopped at the first blank line and silently dropped
    everything after it.

    Args:
        file_name: Path of the descriptor file.

    Returns:
        dict mapping each non-empty line to ``1``.
    """
    ret_dict = {}
    with open(file_name) as fp:
        for line in fp:
            line = line.rstrip("\n")
            if line:
                ret_dict[line] = 1
    return ret_dict
59
+
60
def read_vocab(file_name):
    """Read a vocabulary file (one term per line) into two lookup dicts.

    Args:
        file_name: Path of the vocabulary file.

    Returns:
        ``(o_vocab_dict, l_vocab_dict)`` where ``o_vocab_dict`` maps every
        term to itself and ``l_vocab_dict`` maps the lower-cased term to its
        original casing.
    """
    l_vocab_dict = {}
    o_vocab_dict = {}
    with open(file_name) as fp:
        for line in fp:
            line = line.rstrip('\n')
            if not line:
                continue
            # If there are multiple cased versions they collapse into one
            # (the last one read wins). That is okay since the original is
            # saved in o_vocab_dict; this map is only used when a word is not
            # found in its pristine form in the original list.
            l_vocab_dict[line.lower()] = line
            o_vocab_dict[line] = line
    print("Read vocab file:",len(o_vocab_dict))
    return o_vocab_dict,l_vocab_dict
72
+
73
def consolidate_labels(existing_node,new_labels,new_counts):
    """Consolidate all the labels and counts for terms, ignoring casing.

    For instance, egfr may not have an entity label associated with it
    but eGFR and EGFR may have. So if input is egfr, then this function ensures
    the combined entities set of eGFR and EGFR is made so as to return that union
    for egfr.

    Args:
        existing_node: dict with ``"label"`` and ``"counts"`` keys, each a
            '/'-separated string of equal arity.
        new_labels: '/'-separated labels to merge in.
        new_counts: '/'-separated integer counts matching ``new_labels``.

    Returns:
        dict with ``"label"`` and ``"counts"`` as '/'-separated strings,
        ordered by descending count (ties keep first-seen order).

    Raises:
        AssertionError: if a label string and its count string differ in arity.
        ValueError: if a count is not an integer literal.
    """
    existing_labels_arr = existing_node["label"].split('/')
    existing_counts_arr = existing_node["counts"].split('/')
    new_labels_arr = new_labels.split('/')
    new_counts_arr = new_counts.split('/')
    assert(len(existing_labels_arr) == len(existing_counts_arr))
    assert(len(new_labels_arr) == len(new_counts_arr))

    merged = {}
    # Existing entries are taken as-is (a repeated label keeps its last count).
    for label, cnt in zip(existing_labels_arr, existing_counts_arr):
        merged[label] = int(cnt)
    # New entries are added on top of whatever is already recorded.
    for label, cnt in zip(new_labels_arr, new_counts_arr):
        merged[label] = merged.get(label, 0) + int(cnt)

    ordered = sorted(merged.items(), key=lambda kv: kv[1], reverse=True)
    ret_labels_str = '/'.join(label for label, _ in ordered)
    ret_counts_str = '/'.join(str(cnt) for _, cnt in ordered)
    return {"label":ret_labels_str,"counts":ret_counts_str}
108
+
109
+
110
def read_labels(labels_file):
    """Read the labels file into case-sensitive and case-insensitive maps.

    Each line must hold exactly three whitespace-separated fields:
    ``<label> <counts> <term>``.

    Args:
        labels_file: Path of the UTF-8 encoded labels file.

    Returns:
        ``(terms_dict, lc_terms_dict)``. ``terms_dict`` maps each term to
        ``{"label": ..., "counts": ...}``. ``lc_terms_dict`` is keyed by the
        lower-cased term; entries that collide after lower-casing are merged
        with ``consolidate_labels``.

    Raises:
        AssertionError: on a line that does not have exactly three fields
            (with assertions disabled the line is reported and skipped).
    """
    terms_dict = OrderedDict()
    lc_terms_dict = OrderedDict()
    with open(labels_file,encoding="utf-8") as fin:
        for line in fin:
            fields = line.strip("\n").split()
            if len(fields) != 3:
                print("Invalid line:",fields)
                assert(0)
                continue
            label, counts, term = fields
            terms_dict[term] = {"label":label,"counts":counts}
            lc_term = term.lower()
            if lc_term in lc_terms_dict:
                lc_terms_dict[lc_term] = consolidate_labels(lc_terms_dict[lc_term],label,counts)
            else:
                lc_terms_dict[lc_term] = {"label":label,"counts":counts}
    print("count of labels in " + labels_file + ":", len(terms_dict))
    return terms_dict,lc_terms_dict
131
+
132
+
133
class BatchInference:
    """Batched masked-LM inference producing descriptors for phrases in a sentence.

    For every phrase to be tagged, two sentences are scored in one batch:
    a context-independent (CI) sentence of the form "X is a entity" and a
    context-sensitive (CS) sentence in which the phrase is masked inside the
    original input. Top predictions at selected positions are mapped to entity
    labels through the labels file.
    """

    def __init__(self, config_file,path,to_lower,patched,topk,abbrev,tokmod,vocab_path,labels_file,delimsep):
        """Load the config, labels, model, descriptors and open the log files.

        Args:
            config_file: Path handed to ``cf.read_config``. Keys read:
                ``BASE_PATH`` (optional, default "./"), ``DESC_FILE``
                (optional, default ``DESC_FILE``), ``USE_CLS``, ``LOG_DESCS``
                and ``POS_SERVER_URL``.
            path: Model path handed to ``load_bert_model``.
            to_lower: Lower-casing flag handed to ``load_bert_model``.
            patched: When true, model output is indexed as
                ``predictions[0][0]`` instead of ``predictions[0]``.
            topk: Number of descriptors to collect per position.
            abbrev: Stored on the instance; not used by the code in this class.
            tokmod: When true, input words are mapped to vocabulary terms that
                match case-insensitively (requires ``vocab_path``).
            vocab_path: Directory containing ``vocab.txt`` (read only if
                ``tokmod`` is true).
            labels_file: Path handed to ``read_labels``.
            delimsep: When true, ``normalize_sent`` separates some punctuation
                marks from adjacent text.
        """
        print("Model path:",path,"lower casing set to:",to_lower," is patched ", patched)
        self.path = path
        # Read the config once instead of re-reading it for every key.
        config = cf.read_config(config_file)
        base_path = config["BASE_PATH"] if ("BASE_PATH" in config) else "./"
        desc_file_path = config["DESC_FILE"] if ("DESC_FILE" in config) else DESC_FILE
        self.labels_dict,self.lc_labels_dict = read_labels(labels_file)
        self.tokenizer, self.model = load_bert_model(path,to_lower)
        self.model.eval()
        self.descs = read_descs(desc_file_path)
        self.top_k = topk
        self.patched = patched
        self.abbrev = abbrev
        self.tokmod = tokmod
        self.delimsep = delimsep
        self.truncated_fp = open(base_path + "truncated_sentences.txt","a")
        self.always_log_fp = open(base_path + "CI_LOGS.txt","a")
        # Models like Bert base cased return same prediction for CLS regardless
        # of input. So CLS can be ignored through the config.
        if (config["USE_CLS"] == "1"):
            print("************** USE CLS: Turned ON for this model. ******* ")
            self.use_cls = True
        else:
            print("************** USE CLS: Turned OFF for this model. ******* ")
            self.use_cls = False
        if (config["LOG_DESCS"] == "1"):
            self.log_descs = True
            self.ci_fp = open(base_path + "log_ci_predictions.txt","w")
            self.cs_fp = open(base_path + "log_cs_predictions.txt","w")
        else:
            self.log_descs = False
        self.pos_server_url = config["POS_SERVER_URL"]
        if (tokmod):
            self.o_vocab_dict,self.l_vocab_dict = read_vocab(vocab_path + "/vocab.txt")
        else:
            self.o_vocab_dict = {}
            self.l_vocab_dict = {}

    def close(self):
        """Close every log file the constructor opened (safe to call twice)."""
        for name in ("truncated_fp", "always_log_fp", "ci_fp", "cs_fp"):
            handle = getattr(self, name, None)
            if handle is not None and not handle.closed:
                handle.close()

    def dispatch_request(self,url):
        """GET ``url``, retrying up to 10 times.

        Both transport errors and non-200 responses count as a failed attempt
        (previously a non-200 response was retried forever).

        Returns:
            The ``requests`` response on HTTP 200, otherwise ``None`` once all
            attempts are exhausted.
        """
        max_retries = 10
        for _ in range(max_retries):
            try:
                r = requests.get(url,timeout=1000)
            except requests.RequestException:
                print("Request:", url, " failed. Retrying...")
                continue
            if (r.status_code == 200):
                return r
            print("Request:", url, " failed. Retrying...")
        print("Request:", url, " failed")
        return None

    def modify_text_to_match_vocab(self,text):
        """Replace words by the vocabulary term matching them case-insensitively.

        A word found verbatim in the vocabulary, or not found at all, is kept
        unchanged. Words are split on whitespace and re-joined with one space.
        """
        ret_arr = []
        for word in text.split():
            if (word in self.o_vocab_dict):
                ret_arr.append(word)
            else:
                ret_arr.append(self.l_vocab_dict.get(word.lower(), word))
        return ' '.join(ret_arr)

    #This is bad hack for prototyping - parsing from text output as opposed to json
    def extract_POS(self,text):
        """Parse the POS server's text output.

        The leading run of non-empty lines is skipped up to the first empty
        line; from there on, every line with exactly five tab-separated fields
        is returned as a list of those fields.
        """
        arr = text.split('\n')
        start_pos = 0
        for line in arr:
            if (len(line) == 0):
                break
            start_pos += 1
        terms_arr = []
        for line in arr[start_pos:]:
            terms = line.split('\t')
            if (len(terms) == 5):
                terms_arr.append(terms)
        return terms_arr

    def masked_word_first_letter_capitalize(self,entity):
        """Upper-case the first letter of each word whose first two characters are lower case."""
        ret_arr = []
        for term in entity.split():
            if (len(term) > 1 and term[0].islower() and term[1].islower()):
                ret_arr.append(term[0].upper() + term[1:])
            else:
                ret_arr.append(term)
        return ' '.join(ret_arr)

    def gen_single_phrase_sentences(self,terms_arr,span_arr):
        """Build one CI sentence ("<phrase> is a entity") per contiguous run of 1s in ``span_arr``.

        Args:
            terms_arr: Per-word records; the word is read at ``utils.WORD_POS``.
            span_arr: Per-word flags; consecutive entries equal to 1 form one phrase.

        Returns:
            ``(sentences, singleton_spans_arr)``: the CI sentences and, for
            each, a span list with 1 per phrase word followed by 0 for each
            template word.
        """
        sentence_template = "%s is a entity"
        template_word_count = sum(1 for w in sentence_template.split() if w != "%s")
        sentences = []
        singleton_spans_arr = []
        run_index = 0
        total = len(span_arr)
        while (run_index < total):
            if (span_arr[run_index] != 1):
                run_index += 1
                continue
            entity = ""
            singleton_span = []
            # Consume the whole run of tagged words.
            while (run_index < total and span_arr[run_index] == 1):
                word = terms_arr[run_index][utils.WORD_POS]
                if (len(entity) == 0):
                    entity = word
                else:
                    entity = entity + " " + word
                singleton_span.append(1)
                run_index += 1
            singleton_span.extend([0] * template_word_count)
            entity = self.masked_word_first_letter_capitalize(entity)
            if (self.tokmod):
                entity = self.modify_text_to_match_vocab(entity)
            sentences.append(sentence_template % entity)
            singleton_spans_arr.append(singleton_span)
        return sentences,singleton_spans_arr

    def gen_padded_sentence(self,text,max_tokenized_sentence_length,tokenized_text_arr,orig_tokenized_length_arr,indexed_tokens_arr,attention_mask_arr,to_replace):
        """Tokenize one sentence and append its pieces to the batch lists.

        When ``to_replace`` is true every whitespace-delimited word equal to
        "entity" becomes "[MASK]". The sentence is wrapped in "[CLS] ... [SEP]"
        before tokenization. Padding is applied later by the caller.

        The four ``*_arr`` arguments are mutated in place. Note that
        ``orig_tokenized_length_arr`` receives the token list itself (callers
        take ``len`` of it), not a number.

        Returns:
            The larger of ``max_tokenized_sentence_length`` and this
            sentence's token count.
        """
        if (to_replace):
            text = ' '.join("[MASK]" if w == "entity" else w for w in text.split())
        text = '[CLS] ' + text + ' [SEP]'
        tokenized_text = self.tokenizer.tokenize(text)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        tok_length = len(indexed_tokens)
        indexed_tokens_arr.append(indexed_tokens)
        attention_mask_arr.append([1]*tok_length)
        tokenized_text_arr.append(tokenized_text)
        orig_tokenized_length_arr.append(tokenized_text)
        return max(max_tokenized_sentence_length, tok_length)

    def find_entity(self,word):
        """Look up the entity label and counts for ``word``.

        Lookup order: all-digit word ("MEASURE"), exact match in the labels
        dict, lower-cased match in the labels dict, match in the
        case-insensitive labels dict. An "OTHER" label (found or defaulted) is
        reported as "UNTAGGED_ENTITY" with counts "1".

        Returns:
            ``(label, counts, in_vocab)``; ``in_vocab`` is true only when one
            of the label dicts contained the word.
        """
        entities = self.labels_dict
        lc_entities = self.lc_labels_dict
        in_vocab = False
        l_word = word.lower()
        if l_word.isdigit():
            ret_label = "MEASURE"
            ret_counts = str(1)
        elif (word in entities):
            ret_label = entities[word]["label"]
            ret_counts = entities[word]["counts"]
            in_vocab = True
        elif (l_word in entities):
            ret_label = entities[l_word]["label"]
            ret_counts = entities[l_word]["counts"]
            in_vocab = True
        elif (l_word in lc_entities):
            ret_label = lc_entities[l_word]["label"]
            ret_counts = lc_entities[l_word]["counts"]
            in_vocab = True
        else:
            ret_label = "OTHER"
            ret_counts = "1"
        if (ret_label == "OTHER"):
            ret_label = "UNTAGGED_ENTITY"
            ret_counts = "1"
        return ret_label,ret_counts,in_vocab

    #This is just a trivial hack for consistency of CI prediction of numbers
    def override_ci_number_predictions(self,masked_sent):
        """Return a fixed NUMBER override for sentences of the form "<numeric> is a entity".

        Returns:
            ``(True, "two", "1", "NUMBER")`` when the sentence has exactly four
            words, ends in "is a entity" and starts with a numeric word;
            ``(False, "", "", "")`` otherwise.
        """
        words = masked_sent.split()
        if (len(words) == 4 and words[1:] == ["is", "a", "entity"] and words[0].isnumeric()): #only integers skipped
            return True,"two","1","NUMBER"
        return False,"","",""

    def override_ci_for_vocab_terms(self,masked_sent):
        """Return the labels-file entry for X when the sentence is "X is a entity" and X is known.

        Returns:
            ``(True, X, counts, label)`` when X is found by ``find_entity``;
            ``(False, "", "", "")`` otherwise.
        """
        words = masked_sent.split()
        if (len(words) == 4 and words[1:] == ["is", "a", "entity"]):
            entity,entity_count,in_vocab = self.find_entity(words[0])
            if (in_vocab):
                return True,words[0],entity_count,entity
        return False,"","",""

    def normalize_sent(self,sent):
        """Normalize the input sentence.

        Trailing whitespace is removed. For sentences longer than one
        character: if ``self.delimsep`` is set, each of ``!"%();?[]`{}`` is
        surrounded by spaces; then, unless the sentence ends with
        ":__entity__" or with one of ``!,.:;?``, " . " is appended.
        """
        normalized_tokens = "!\"%();?[]`{}"
        end_tokens = "!,.:;?"
        sent = sent.rstrip()
        if (len(sent) > 1):
            if (self.delimsep):
                for tok in normalized_tokens:
                    sent = sent.replace(tok,' ' + tok + ' ')
            sent = sent.rstrip()
            if (not sent.endswith(":__entity__")):
                if (sent[-1] not in end_tokens): #End all sentences with a period if not already present in sentence.
                    sent = sent + ' . '
        print("Normalized sent",sent)
        return sent

    def truncate_sent_if_too_long(self,text):
        """Drop trailing words until the tokenized sentence is shorter than ``MAX_TOKENIZED_SENT_LENGTH``.

        When anything was dropped, the count, the original sentence and the
        truncated sentence are written to the truncation log.

        Returns:
            The (possibly truncated) text.
        """
        truncated_count = 0
        orig_sent = text
        while (True):
            tok_text = '[CLS] ' + text + ' [SEP]'
            tokenized_text = self.tokenizer.tokenize(tok_text)
            if (len(tokenized_text) < MAX_TOKENIZED_SENT_LENGTH):
                break
            # Remove the last whitespace-delimited word and try again.
            text = ' '.join(text.split()[:-1])
            truncated_count += 1
        if (truncated_count > 0):
            print("Input sentence was truncated by: ", truncated_count, " tokens")
            self.truncated_fp.write("Input sentence was truncated by: " + str(truncated_count) + " tokens\n")
            self.truncated_fp.write(orig_sent + "\n")
            self.truncated_fp.write(text + "\n\n")
        return text

    def get_descriptors(self,sent,pos_arr):
        '''
        Batched creation of descriptors given a sentence.
            1) Find noun phrases to tag in a sentence if user did not explicitly tag.
            2) Create 'N' CS and CI sentences if there are N phrases to tag. Total 2*N sentences
            3) Create a batch padding all sentences to the maximum sentence length.
            4) Perform inference on batch
            5) Return json of descriptors for the original sentence as well as all CI sentences

        Args:
            sent: Input sentence; words suffixed with ":__entity__" are the
                ones to tag.
            pos_arr: Pre-computed term records used when the sentence carries
                no explicit tag.

        Returns:
            dict with keys "input", "terms_arr", "span_arr" and
            "descs_and_entities". The last is an OrderedDict keyed by 1-based
            phrase number, each holding "ci_prediction" and "cs_prediction"
            entries with "sentence" and "descs".
        '''
        #Truncate sent if the tokenized sent is longer than max sent length
        sent = self.truncate_sent_if_too_long(sent)
        #This is a modification of input text to words in vocab that match it in case insensitive manner.
        #This is *STILL* required when we are using subwords too for prediction. The prediction quality is still better.
        #An example is Mesothelioma is caused by exposure to asbestos. The quality of prediction is better when Mesothelioma is not split by lowercasing with A100 model
        if (self.tokmod):
            sent = self.modify_text_to_match_vocab(sent)

        #The input sentence is normalized. Specifically all input is terminated with a punctuation if not already present. Also some of the punctuation marks are separated from text if glued to a word(disabled by default for test set sync)
        sent = self.normalize_sent(sent)

        #Step 1. Find entities to tag if user did not explicitly tag terms
        #All noun phrases are tagged for prediction
        if (SPECIFIC_TAG in sent):
            terms_arr = utils.set_POS_based_on_entities(sent)
        elif (pos_arr is None):
            # The assert makes the POS-server fallback below unreachable
            # unless assertions are disabled (python -O).
            assert(0)
            url = self.pos_server_url + sent.replace('"','\'')
            r = self.dispatch_request(url)
            if r is None:
                raise RuntimeError("POS server request failed: " + url)
            terms_arr = self.extract_POS(r.text)
        else:
            terms_arr = pos_arr

        print(terms_arr)
        #Note span arr only contains phrases in the input that need to be tagged - not the span of all phrases in sentences
        #Step 2. Create N CS sentences
        #This returns masked sentences for all positions
        main_sent_arr,masked_sent_arr,span_arr = utils.detect_masked_positions(terms_arr)
        #This is a boundary condition to avoid using cs if the input is just trying to get entity type for a phrase. There is no sentence context in that case.
        ignore_cs = (len(masked_sent_arr) == 1 and len(masked_sent_arr[0]) == 2 and masked_sent_arr[0][0] == "__entity__" and masked_sent_arr[0][1] == ".")

        #Step 2. Create N CI sentences
        singleton_sentences,not_used_singleton_spans_arr = self.gen_single_phrase_sentences(terms_arr,span_arr)

        #We now have 2*N sentences. They are interleaved: even batch rows are CI, odd rows are CS.
        max_tokenized_sentence_length = 0
        tokenized_text_arr = []
        indexed_tokens_arr = []
        attention_mask_arr = []
        all_sentences_arr = []
        orig_tokenized_length_arr = []
        assert(len(masked_sent_arr) == len(singleton_sentences))
        for ci_s,cs_s in zip(singleton_sentences,masked_sent_arr):
            all_sentences_arr.append(ci_s)
            max_tokenized_sentence_length = self.gen_padded_sentence(ci_s,max_tokenized_sentence_length,tokenized_text_arr,orig_tokenized_length_arr,indexed_tokens_arr,attention_mask_arr,True)
            cs_s = ' '.join(cs_s).replace("__entity__","entity")
            all_sentences_arr.append(cs_s)
            max_tokenized_sentence_length = self.gen_padded_sentence(cs_s,max_tokenized_sentence_length,tokenized_text_arr,orig_tokenized_length_arr,indexed_tokens_arr,attention_mask_arr,True)

        #pad all sentences with length less than max sentence length. This includes the full sentence too since we used indexed_tokens_arr
        pad_id = self.tokenizer.pad_token_id
        for ids, mask in zip(indexed_tokens_arr, attention_mask_arr):
            missing = max_tokenized_sentence_length - len(ids)
            if (missing > 0):
                ids.extend([pad_id]*missing)
                mask.extend([0]*missing)

        assert(len(main_sent_arr) == len(span_arr))
        assert(len(all_sentences_arr) == len(indexed_tokens_arr))
        assert(len(all_sentences_arr) == len(attention_mask_arr))
        assert(len(all_sentences_arr) == len(tokenized_text_arr))
        assert(len(all_sentences_arr) == len(orig_tokenized_length_arr))
        # Convert inputs to PyTorch tensors
        tokens_tensor = torch.tensor(indexed_tokens_arr)
        attention_tensors = torch.tensor(attention_mask_arr)

        print("Input:",sent)
        ret_obj = OrderedDict()
        with torch.no_grad():
            predictions = self.model(tokens_tensor, attention_mask=attention_tensors)
            for sent_index in range(len(predictions[0])):
                is_ci = (sent_index % 2 == 0)
                sentence = all_sentences_arr[sent_index]
                tokens = tokenized_text_arr[sent_index]
                if (self.log_descs):
                    fp = self.ci_fp if is_ci else self.cs_fp
                    fp.write("\nCurrent sentence: " + sentence + "\n")
                prediction = "ci_prediction" if is_ci else "cs_prediction"
                out_index = sent_index // 2 + 1
                if (out_index not in ret_obj):
                    ret_obj[out_index] = {}
                assert(prediction not in ret_obj[out_index])
                curr_sent_arr = []
                ret_obj[out_index][prediction] = {"sentence": sentence, "descs": curr_sent_arr}

                last_pos = len(tokens) - 1
                ci_entity_pos = len(orig_tokenized_length_arr[sent_index]) - 2
                for word in range(len(tokens)):
                    if (word == last_pos): # SEP is skipped for CI and CS
                        continue
                    if (is_ci):
                        #For all CI sentences pick only CLS (position 0) and the last word of the sentence (X is a entity)
                        if (word != 0 and word != ci_entity_pos):
                            continue
                        #This is for models like bert base cased where we cant use CLS - it is the same for all words.
                        if (word == 0 and not self.use_cls):
                            continue
                    elif (tokens[word] != "[MASK]"): # for all CS sentences skip all terms except the mask position
                        continue

                    masked_index = word
                    #pick all model predictions for current position word
                    if (self.patched):
                        position_scores = predictions[0][0][sent_index][masked_index]
                    else:
                        position_scores = predictions[0][sent_index][masked_index]
                    # One bulk conversion instead of a per-vocab-id call.
                    scores = position_scores.tolist()
                    vocab_tokens = self.tokenizer.convert_ids_to_tokens(list(range(len(scores))))
                    results_dict = dict(zip(vocab_tokens, scores))
                    k = 0
                    #sort it - big to small
                    sorted_d = OrderedDict(sorted(results_dict.items(), key=lambda kv: kv[1], reverse=True))

                    if (self.log_descs):
                        fp.write("********* Top predictions for token: " + tokens[word] + "\n")
                    if (is_ci and self.use_cls):
                        #For CI sentences, just pick half for CLS and entity position to match with CS counts
                        top_k = self.top_k/2
                    else:
                        #CS sentences, or CI when [CLS] is not used (models like BBC): take all top k from the one position
                        top_k = self.top_k
                    #Looping through each descriptor prediction for a position and picking it up subject to some conditions
                    for index in sorted_d:
                        if index.lower() in self.descs: #these have almost no entity info - glue words like "the","a"
                            continue
                        if (index in string.punctuation or len(index) == 1 or index.startswith('.') or index.startswith('[')):
                            continue
                        if (index.startswith("#")): #subwords suggest model is trying to predict a multi word term that generally tends to be noisy. So penalize. Count and skip
                            k += 1
                            continue
                        predicted_score = sorted_d[index]
                        entity,entity_count,dummy = self.find_entity(index)
                        if (not is_ci):
                            #CS predictions
                            value = str(round(float(predicted_score),4))
                            if (self.log_descs):
                                self.cs_fp.write(index + " " + entity + " " + entity_count + " " + value + "\n")
                            if (not ignore_cs):
                                curr_sent_arr.append({"desc":index,"e":entity,"e_count":entity_count,"v":value})
                                # NOTE(review): this write is assumed to sit under
                                # `not ignore_cs`; the nesting could not be confirmed
                                # from the flattened listing - verify.
                                if (sentence.strip().rstrip(".").strip().endswith("entity")):
                                    self.always_log_fp.write(' '.join(sentence.split()[:-1]) + " " + index + " :__entity__\n")
                        else:
                            #CI predictions of the form X is a entity
                            #index is one of the predicted descs for the [CLS]/[MASK] position
                            desc = index
                            #Note this override just uses the sentence to override all descs
                            number_override,override_index,override_entity_count,override_entity = self.override_ci_number_predictions(sentence)
                            if (number_override):
                                desc = override_index
                                entity_count = override_entity_count
                                entity = override_entity
                            elif (not self.use_cls or word != 0):
                                #this also uses the sentence to override, ignoring descs
                                override,override_index,override_entity_count,override_entity = self.override_ci_for_vocab_terms(sentence)
                                if (override):
                                    desc = override_index
                                    entity_count = override_entity_count
                                    entity = override_entity
                                    k = top_k #just add this override once. We dont have to add this override for each descriptor and inundate downstream NER with the same signature
                            # The score is looked up under the (possibly overridden)
                            # descriptor; when that term is not a model token, fall
                            # back to the predicted descriptor's score instead of
                            # raising KeyError.
                            value = str(round(float(sorted_d.get(desc, predicted_score)),4))
                            if (self.log_descs):
                                self.ci_fp.write(desc + " " + entity + " " + entity_count + " " + value + "\n")
                            curr_sent_arr.append({"desc":desc,"e":entity,"e_count":entity_count,"v":value})
                            if (desc != "two" and not sentence.strip().startswith("is ")):
                                self.always_log_fp.write(' '.join(sentence.split()[:-1]) + " " + desc + " :__entity__\n")
                        k += 1
                        if (k >= top_k):
                            break
        final_obj = {"input":sent,"terms_arr":main_sent_arr,"span_arr":span_arr,"descs_and_entities":ret_obj}
        if (self.log_descs):
            self.ci_fp.flush()
            self.cs_fp.flush()
        self.always_log_fp.flush()
        self.truncated_fp.flush()
        return final_obj
585
+
586
+
587
+ test_arr = [
588
+ "ajit? is an engineer .",
589
+ "Sam:__entity__ Malone:__entity__ .",
590
+ "1. Jesper:__entity__ Ronnback:__entity__ ( Sweden:__entity__ ) 25.76 points",
591
+ "He felt New York has a chance:__entity__ to win this year's competition .",
592
+ "The new omicron variant could increase the likelihood that people will need a fourth coronavirus vaccine dose earlier than expected, executives at Prin dummy:__entity__ said Wednesday .",
593
+ "The new omicron variant could increase the likelihood that people will need a fourth coronavirus vaccine dose earlier than expected, executives at pharmaceutical:__entity__ giant:__entity__ Pfizer:__entity__ said Wednesday .",
594
+ "The conditions:__entity__ in the camp were very poor",
595
+ "Imatinib:__entity__ is used to treat nsclc",
596
+ "imatinib:__entity__ is used to treat nsclc",
597
+ "imatinib:__entity__ mesylate:__entity__ is used to treat nsclc",
598
+ "Staten is a :__entity__",
599
+ "John is a :__entity__",
600
+ "I met my best friend at eighteen :__entity__",
601
+ "I met my best friend at Parkinson's",
602
+ "e",
603
+ "Bandolier - Budgie ' , a free itunes app for ipad , iphone and ipod touch , released in December 2011 , tells the story of the making of Bandolier in the band 's own words - including an extensive audio interview with Burke Shelley",
604
+ "The portfolio manager of the new cryptocurrency firm underwent a bone marrow biopsy: for AML:__entity__:",
605
+ "Coronavirus:__entity__ disease 2019 (COVID-19) is a contagious disease caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2). The first known case was identified in Wuhan, China, in December 2019.[7] The disease has since spread worldwide, leading to an ongoing pandemic.[8]Symptoms of COVID-19 are variable, but often include fever,[9] cough, headache,[10] fatigue, breathing difficulties, and loss of smell and taste.[11][12][13] Symptoms may begin one to fourteen days after exposure to the virus. At least a third of people who are infected do not develop noticeable symptoms.[14] Of those people who develop symptoms noticeable enough to be classed as patients, most (81%) develop mild to moderate symptoms (up to mild pneumonia), while 14% develop severe symptoms (dyspnea, hypoxia, or more than 50% lung involvement on imaging), and 5% suffer critical symptoms (respiratory failure, shock, or multiorgan dysfunction).[15] Older people are at a higher risk of developing severe symptoms. Some people continue to experience a range of effects (long COVID) for months after recovery, and damage to organs has been observed.[16] Multi-year studies are underway to further investigate the long-term effects of the disease.[16]COVID-19 transmits when people breathe in air contaminated by droplets and small airborne particles containing the virus. The risk of breathing these in is highest when people are in close proximity, but they can be inhaled over longer distances, particularly indoors. Transmission can also occur if splashed or sprayed with contaminated fluids in the eyes, nose or mouth, and, rarely, via contaminated surfaces. People remain contagious for up to 20 days, and can spread the virus even if they do not develop symptoms.[17][18]Several testing methods have been developed to diagnose the disease. 
The standard diagnostic method is by detection of the virus' nucleic acid by real-time reverse transcription polymerase chain reaction (rRT-PCR), transcription-mediated amplification (TMA), or by reverse transcription loop-mediated isothermal amplification (RT-LAMP) from a nasopharyngeal swab.Several COVID-19 vaccines have been approved and distributed in various countries, which have initiated mass vaccination campaigns. Other preventive measures include physical or social distancing, quarantining, ventilation of indoor spaces, covering coughs and sneezes, hand washing, and keeping unwashed hands away from the face. The use of face masks or coverings has been recommended in public settings to minimize the risk of transmissions. While work is underway to develop drugs that inhibit the virus, the primary treatment is symptomatic. Management involves the treatment of symptoms, supportive care, isolation, and experimental measures.",
606
+ "imatinib was used to treat Michael Jackson . ",
607
+ "eg .",
608
+ "mesothelioma is caused by exposure to organic :__entity__",
609
+ "Mesothelioma is caused by exposure to asbestos:__entity__",
610
+ "Asbestos is a highly :__entity__",
611
+ "Fyodor:__entity__ Mikhailovich:__entity__ Dostoevsky:__entity__ was treated for Parkinsons:__entity__ and later died of lung carcinoma",
612
+ "Fyodor:__entity__ Mikhailovich:__entity__ Dostoevsky:__entity__",
613
+ "imatinib was used to treat Michael:__entity__ Jackson:__entity__",
614
+ "Ajit flew to Boston:__entity__",
615
+ "Ajit:__entity__ flew to Boston",
616
+ "A eGFR below 60:__entity__ indicates chronic kidney disease",
617
+ "imatinib was used to treat Michael Jackson",
618
+ "Ajit Valath:__entity__ Rajasekharan is an engineer at nFerence headquartered in Cambrigde MA",
619
+ "imatinib:__entity__",
620
+ "imatinib",
621
+ "iplimumab:__entity__",
622
+ "iplimumab",
623
+ "engineer:__entity__",
624
+ "engineer",
625
+ "Complications include peritonsillar:__entity__ abscess::__entity__",
626
+ "Imatinib was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic:__entity__ myeloid:__entity__ leukemia:__entity__ (CML,",
627
+ "Imatinib was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic myeloid leukemia (CML,",
628
+ "Imatinib was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic:__entity__ myeloid:___entity__ leukemia:__entity__ (CML,",
629
+ "Ajit Rajasekharan is an engineer:__entity__ at nFerence:__entity__",
630
+ "Imatinib was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic myeloid leukemia (CML,",
631
+ "Ajit:__entity__ Rajasekharan:__entity__ is an engineer",
632
+ "Imatinib:__entity__ was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic myeloid leukemia (CML,",
633
+ "Ajit Valath Rajasekharan is an engineer at nFerence headquartered in Cambrigde MA",
634
+ "Ajit:__entity__ Valath Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde MA",
635
+ "Ajit:__entity__ Valath:__entity__ Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde MA",
636
+ "Ajit:__entity__ Valath:__entity__ Rajasekharan:__entity__ is an engineer:__entity__ at nFerence headquartered in Cambrigde MA",
637
+ "Ajit Raj is an engineer:__entity__ at nFerence",
638
+ "Ajit Valath:__entity__ Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde:__entity__ MA",
639
+ "Ajit Valath Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde:__entity__ MA",
640
+ "Ajit Valath Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde MA",
641
+ "Ajit Valath Rajasekharan is an engineer at nFerence headquartered in Cambrigde MA",
642
+ "Ajit:__entity__ Rajasekharan:__entity__ is an engineer at nFerence:__entity__",
643
+ "Imatinib mesylate is used to treat non small cell lung cancer",
644
+ "Imatinib mesylate is used to treat :__entity__",
645
+ "Imatinib is a term:__entity__",
646
+ "nsclc is a term:__entity__",
647
+ "Ajit Rajasekharan is a term:__entity__",
648
+ "ajit rajasekharan is a term:__entity__",
649
+ "John Doe is a term:__entity__"
650
+ ]
651
+
652
+
653
def test_sentences(singleton,iter_val):
    """Run each sentence in ``iter_val`` through ``singleton`` and log the output.

    Every sentence is echoed to stdout and written to ``debug.txt`` (which is
    truncated on each call), followed by the JSON-formatted result of
    ``singleton.get_descriptors(sentence)``.

    Args:
        singleton: object exposing ``get_descriptors(sentence)``.
        iter_val: iterable of sentence strings; a trailing newline on each
            item is stripped, so an open text file works as well as a list.
    """
    with open("debug.txt","w") as fp:
        for test in iter_val:
            test = test.rstrip('\n')
            fp.write(test + "\n")
            print(test)
            out = singleton.get_descriptors(test)
            print(out)
            fp.write(json.dumps(out,indent=4))
            # Flush per sentence so debug.txt is current while paused in the debugger.
            fp.flush()
            print()
            # NOTE(review): drops into the interactive debugger after each
            # sentence - presumably intentional for manual inspection; confirm
            # that this belongs inside the loop and remove for unattended runs.
            pdb.set_trace()
665
+
666
+
667
if __name__ == '__main__':
    # Command-line entry point: build a BatchInference instance from the
    # parsed options and run it over either the canned sentences (test_arr)
    # or the lines of the file given with -input.
    parser = argparse.ArgumentParser(description='BERT descriptor service given a sentence. The word to be masked is specified as the special token entity ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-config', action="store", dest="config", default=DEFAULT_CONFIG,help='config file path')
    parser.add_argument('-model', action="store", dest="model", default=DEFAULT_MODEL_PATH,help='BERT pretrained models, or custom model path')
    parser.add_argument('-input', action="store", dest="input", default="",help='Optional input file with sentences. If not specified, assumed to be canned sentence run (default behavior)')
    parser.add_argument('-topk', action="store", dest="topk", default=DEFAULT_TOP_K,type=int,help='Number of neighbors to display')
    parser.add_argument('-tolower', dest="tolower", action='store_true',help='Convert tokens to lowercase. Set to True only for uncased models')
    parser.add_argument('-no-tolower', dest="tolower", action='store_false',help='Convert tokens to lowercase. Set to True only for uncased models')
    parser.add_argument('-patched', dest="patched", action='store_true',help='Is pytorch code patched to harvest [CLS]')
    parser.add_argument('-no-patched', dest="patched", action='store_false',help='Is pytorch code patched to harvest [CLS]')
    parser.add_argument('-abbrev', dest="abbrev", action='store_true',help='Just output pivots - not all neighbors')
    parser.add_argument('-no-abbrev', dest="abbrev", action='store_false',help='Just output pivots - not all neighbors')
    parser.add_argument('-tokmod', dest="tokmod", action='store_true',help='Modify input token casings to match vocab - meaningful only for cased models')
    parser.add_argument('-no-tokmod', dest="tokmod", action='store_false',help='Modify input token casings to match vocab - meaningful only for cased models')
    parser.add_argument('-vocab', action="store", dest="vocab", default=DEFAULT_MODEL_PATH,help='Path to vocab file. This is required only if tokmod is true')
    parser.add_argument('-labels', action="store", dest="labels", default=DEFAULT_LABELS_PATH,help='Path to labels file. This returns labels also')
    parser.add_argument('-delimsep', dest="delimsep", action='store_true',help='Modify input tokens where delimiters are stuck to tokens. Turned off by default to be in sync with test sets')
    # The negative switch must store False; it previously used store_true,
    # which made -no-delimsep enable the very option it is meant to disable.
    parser.add_argument('-no-delimsep', dest="delimsep", action='store_false',help='Modify input tokens where delimiters are stuck to tokens. Turned off by default to be in sync with test sets')
    # Defaults for the paired on/off switches, set in one place.
    parser.set_defaults(tolower=False,patched=False,abbrev=True,tokmod=True,delimsep=False)

    results = parser.parse_args()
    try:
        singleton = BatchInference(results.config,results.model,results.tolower,results.patched,results.topk,results.abbrev,results.tokmod,results.vocab,results.labels,results.delimsep)
        print("To lower casing is set to:",results.tolower)
        if (len(results.input) == 0):
            print("Canned test mode")
            test_sentences(singleton,test_arr)
        else:
            print("Batch file test mode")
            # Context manager so the input file is closed even if inference fails.
            with open(results.input) as fp:
                test_sentences(singleton,fp)
    except Exception:
        # Top-level boundary: report the failure with a traceback. Catching
        # Exception (not a bare except) lets Ctrl-C / SystemExit end the run.
        print("Unexpected error:", sys.exc_info()[0])
        traceback.print_exc(file=sys.stdout)
707
+
aggregate_server_json.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ import threading
3
+ import time
4
+ import math
5
+ import sys
6
+ import pdb
7
+ import requests
8
+ import urllib.parse
9
+ from common import *
10
+ import config_utils as cf
11
+ import json
12
+ from collections import OrderedDict
13
+ import argparse
14
+ import numpy as np
15
+
16
+
17
MASK = ":__entity__"
RESULT_MASK = "NER_FINAL_RESULTS:"
DEFAULT_CONFIG = "./ensemble_config.json"

DEFAULT_TEST_BATCH_FILE="bootstrap_test_set.txt"
NER_OUTPUT_FILE="ner_output.txt"
DEFAULT_THRESHOLD = 1 #1 standard deviation from mean - for cross over prediction

# Per-server descriptors (url, desc, precedence, common). Filled in by
# AggregateNER.__init__ from the config file and read by
# AggregateNER.ensemble_processing.
actions_arr = []
26
+
27
+ class AggregateNER:
28
def __init__(self,config_file):
    """Open the log files and load the per-server settings from ``config_file``.

    The config is parsed once and reused for every lookup. Besides setting
    instance attributes, this rebinds the module-level ``actions_arr`` to a
    two-entry list (first entry uses ``bio_precedence_arr``, second uses
    ``phi_precedence_arr``), each with ``url``, ``desc``, ``precedence`` and
    ``common`` keys.

    Args:
        config_file: path handed to ``cf.read_config``. The result must
            provide ``NER_SERVERS``, ``actions_arr`` (at least two entries),
            ``bio_precedence_arr``, ``phi_precedence_arr`` and
            ``common_entities_arr``; ``BASE_PATH`` is optional and defaults
            to ``"./"``.
    """
    global actions_arr
    # Parse the config a single time instead of once per key lookup.
    config = cf.read_config(config_file)
    base_path = config["BASE_PATH"] if ("BASE_PATH" in config) else "./"
    # Log files are opened in append mode and stay open for the lifetime of
    # the instance; nothing in this method closes them.
    self.error_fp = open(base_path + "failed_queries_log.txt","a")
    self.rfp = open(base_path + "query_response_log.txt","a")
    self.query_log_fp = open(base_path + "query_logs.txt","a")
    self.inferred_entities_log_fp = open(base_path + "inferred_entities_log.txt","a")
    self.threshold = DEFAULT_THRESHOLD #TBD read this from config. cf.read_config()["CROSS_OVER_THRESHOLD_SIGMA"]
    self.servers = config["NER_SERVERS"]
    server_actions = config["actions_arr"]
    # NOTE: both entries now reference the same common_entities_arr object
    # (previously each came from a separate parse of the config).
    common_entities = config["common_entities_arr"]
    actions_arr = [
        {"url":server_actions[0]["url"],"desc":server_actions[0]["desc"],"precedence":config["bio_precedence_arr"],"common":common_entities},
        {"url":server_actions[1]["url"],"desc":server_actions[1]["desc"],"precedence":config["phi_precedence_arr"],"common":common_entities},
    ]
41
+
42
def add_term_punct(self,sent):
    """Return ``sent`` terminated with `` . `` unless it already ends in punctuation.

    Inputs of length 0 or 1 are returned unchanged, as are sentences whose
    final character is one of ``!,.:;?``.
    """
    if len(sent) <= 1:
        return sent
    if sent[-1] in "!,.:;?":
        return sent
    # No terminating punctuation: append a space-separated period.
    sent = sent + ' . '
    print("End punctuated sent:",sent)
    return sent
50
+
51
def fetch_all(self,inp,model_results_arr):
    """Ensemble the per-server results for ``inp`` and log query and response.

    Args:
        inp: the raw input sentence.
        model_results_arr: list of per-server result objects, passed on to
            ``ensemble_processing``.

    Returns:
        The dict produced by ``ensemble_processing`` with an added ``stats``
        entry holding the server count (as a string) and a ``return_status``
        of ``"Failed"`` when no entities were ensembled, else ``"Success"``.
    """
    query_log = self.query_log_fp
    query_log.write(inp+"\n")
    query_log.flush()

    # The ensembling step works on the end-punctuated sentence.
    punctuated = self.add_term_punct(inp)
    results = self.ensemble_processing(punctuated,model_results_arr)

    if len(results["ensembled_ner"]) == 0:
        return_stat = "Failed"
    else:
        return_stat = "Success"
    results["stats"] = { "Ensemble server count" : str(len(model_results_arr)), "return_status": return_stat}

    response_log = self.rfp
    response_log.write( "\n" + json.dumps(results,indent=4))
    response_log.flush()
    return results
68
+
69
+
70
def get_conflict_resolved_entity(self,results,term_index,terms_count,servers_arr):
    """Decide which of the two servers to trust for the term at ``term_index``.

    Returns:
        A ``(server_index, span_count, cross_prediction_count)`` tuple. The
        last element is -1 whenever the choice is made here; otherwise it is
        whatever ``pick_single_server_if_possible`` reports.
    """
    pos_index = str(term_index + 1)
    first_entity = extract_main_entity(results,0,pos_index)
    second_entity = extract_main_entity(results,1,pos_index)
    first_span = get_span_info(results,0,term_index,terms_count)
    second_span = get_span_info(results,1,term_index,terms_count)

    resolved_span = first_span
    if first_span != second_span:
        print("Both input spans dont match. This is the effect of normalized casing that is model specific. Picking min span length")
        resolved_span = min(first_span,second_span)

    if first_entity == second_entity:
        # Agreement: the first server wins when the entity is in its precedence list.
        chosen = 0 if (first_entity in servers_arr[0]["precedence"]) else 1
        if first_entity != "O":
            print("Both servers agree on prediction for term:",results[0]["ner"][pos_index]["term"],":",first_entity)
        return chosen,resolved_span,-1

    print("Servers do not agree on prediction for term:",results[0]["ner"][pos_index]["term"],":",first_entity,second_entity)
    if second_entity == "O":
        print("Server 2 returned O. Picking server 1")
        return 0,resolved_span,-1
    if first_entity == "O":
        print("Server 1 returned O. Picking server 2")
        # Server 2's own span is used here, not the minimum of the two.
        return 1,second_span,-1

    # Both servers made a real but different prediction. First server is BIO,
    # second is PHI. If only one predicts confidently, or one is predicting
    # outside its own domain (a cross prediction), the other is picked;
    # otherwise both predictions are returned, higher confidence first.
    # Cross prediction is checked only for predictions above the mean.
    chosen,cross_prediction_count = self.pick_single_server_if_possible(results,term_index,servers_arr)
    return chosen,resolved_span,cross_prediction_count
102
+
103
def pick_single_server_if_possible(self,results,term_index,servers_arr):
    """Pick one server for a term both servers tagged differently.

    Args:
        results: per-server result dicts; ``entity_distribution`` and
            ``orig_cs_prediction_details`` are read for the term position.
        term_index: zero-based term index (results are keyed by the
            one-based index as a string).
        servers_arr: per-server descriptors handed to
            ``is_included_in_server_entities``.

    Returns:
        ``(server_index, cross_prediction_count)``. The second element is -1
        when a single server was selected and 2 when both predictions should
        be emitted (``server_index`` is then the higher-confidence server).
    """
    pos_index = str(term_index + 1)
    predictions_dict = {}
    orig_cs_predictions_dict = {}
    # NOTE(review): single_prediction_count is never incremented (the update
    # below is commented out), so the first two branches of the if/elif chain
    # are currently unreachable and only the final else branch runs.
    single_prediction_count = 0
    single_prediction_server_index = -1
    for server_index in range(len(results)):
        if (pos_index in results[server_index]["entity_distribution"]):
            predictions = self.get_predictions_above_threshold(results[server_index]["entity_distribution"][pos_index])
            predictions_dict[server_index] = predictions #This is used below to only return top server prediction

            orig_cs_predictions = self.get_predictions_above_threshold(results[server_index]["orig_cs_prediction_details"][pos_index])
            orig_cs_predictions_dict[server_index] = orig_cs_predictions #this is used below for cross prediction determination since it is just a CS prediction
            #single_prediction_count += 1 if (len(orig_cs_predictions) == 1) else 0
            #if (len(orig_cs_predictions) == 1):
            #    single_prediction_server_index = server_index
    if (single_prediction_count == 1):
        is_included = is_included_in_server_entities(orig_cs_predictions_dict[single_prediction_server_index],servers_arr[single_prediction_server_index],False)
        if(is_included == False) :
            print("This is an odd case of single server prediction, that is a cross over")
            ret_index = 0 if single_prediction_server_index == 1 else 1
            return ret_index,-1
        else:
            print("Returning the index of single prediction server")
            return single_prediction_server_index,-1
    elif (single_prediction_count == 2):
        print("Both have single predictions")
        cross_predictions = {}
        cross_prediction_count = 0
        for server_index in range(len(results)):
            if (pos_index in results[server_index]["entity_distribution"]):
                is_included = is_included_in_server_entities(orig_cs_predictions_dict[server_index],servers_arr[server_index],False)
                cross_predictions[server_index] = not is_included
                cross_prediction_count += 1 if not is_included else 0
        if (cross_prediction_count == 2):
            #this is an odd case of both cross predicting with high confidence. Not sure if we will ever come here.
            print("*********** BOTH servers are cross predicting! ******")
            return self.pick_top_server_prediction(predictions_dict),2
        elif (cross_prediction_count == 0):
            #Neither are cross predicting
            print("*********** BOTH servers have single predictions within their domain - returning both ******")
            return self.pick_top_server_prediction(predictions_dict),2
        else:
            print("Returning just the server that is not cross predicting, dumping the cross prediction")
            ret_index = 1 if cross_predictions[0] == True else 0 #Given a server cross predicts, return the other server index
            return ret_index,-1
    else:
        print("*** Both servers have multiple predictions above mean")
        #both have multiple predictions above mean
        cross_predictions = {}
        # The strict_* values are computed but only consumed by the
        # commented-out logic further down.
        strict_cross_predictions = {}
        cross_prediction_count = 0
        strict_cross_prediction_count = 0
        for server_index in range(len(results)):
            if (pos_index in results[server_index]["entity_distribution"]):
                is_included = is_included_in_server_entities(orig_cs_predictions_dict[server_index],servers_arr[server_index],False)
                strict_is_included = strict_is_included_in_server_entities(orig_cs_predictions_dict[server_index],servers_arr[server_index],False)
                cross_predictions[server_index] = not is_included
                strict_cross_predictions[server_index] = not strict_is_included
                cross_prediction_count += 1 if not is_included else 0
                strict_cross_prediction_count += 1 if not strict_is_included else 0
        if (cross_prediction_count == 2):
            print("*********** BOTH servers are ALSO cross predicting and have multiple predictions above mean ******")
            return self.pick_top_server_prediction(predictions_dict),2
        elif (cross_prediction_count == 0):
            print("*********** BOTH servers are ALSO predicting within their domain ******")
            #if just one of them is predicting in the common set, then just pick the server that is predicting in its primary set.
            #if (strict_cross_prediction_count == 1):
            #    ret_index = 1 if (0 not in strict_cross_predictions or strict_cross_predictions[0] == True) else 0 #Given a server cross predicts, return the other server index
            #    return ret_index,-1
            #else:
            #    return self.pick_top_server_prediction(predictions_dict),2
            return self.pick_top_server_prediction(predictions_dict),2
        else:
            print("Returning just the server that is not cross predicting, dumping the cross prediction. This is mainly to reduce the noise in prefix predictions that show up in CS context predictions")
            # Server 1 is returned when server 0 cross predicts or has no
            # distribution entry for this position; otherwise server 0.
            ret_index = 1 if (0 not in cross_predictions or cross_predictions[0] == True) else 0 #Given a server cross predicts, return the other server index
            return ret_index,-1
183
+ #print("*********** One of them is also cross predicting ******")
184
+ #return self.pick_top_server_prediction(predictions_dict),2
185
+
186
+
187
+
188
def pick_top_server_prediction(self,predictions_dict):
    """Return the index (0 or 1) of the server whose top prediction is stronger.

    ``predictions_dict`` maps server index to a list of ``{"e", "conf"}``
    entries. Server 0 wins ties, and 0 is also returned whenever the dict
    does not hold exactly two entries.
    """
    if len(predictions_dict) == 2:
        first_conf = predictions_dict[0][0]["conf"]
        second_conf = predictions_dict[1][0]["conf"]
        return 0 if (first_conf >= second_conf) else 1
    return 0
195
+
196
+
197
def get_predictions_above_threshold(self,predictions):
    """Return the leading entries of ``cs_distribution`` at or above the mean.

    The threshold is ``1.0 / len(distribution)`` (0 for an empty one).
    Scanning stops at the first entry below it, so the distribution is
    expected to be sorted by descending confidence.

    Returns:
        A list of ``{"e": ..., "conf": ...}`` dicts; empty only when the
        distribution itself is empty.
    """
    dist = predictions["cs_distribution"]
    # The input is a probability distribution, so its mean is 1/len.
    threshold = 1.0/len(dist) if len(dist) != 0 else 0
    selected = []
    for node in dist:
        if not (node["confidence"] >= threshold):
            break #this is a reverse sorted list. So no need to check anymore
        selected.append({"e":node["e"],"conf":node["confidence"]})
    if len(dist) > 0:
        assert(len(selected) > 0)
    return selected
222
+
223
def check_if_entity_in_arr(self,entity,arr):
    """Return True when some node in ``arr`` has ``node["e"] == entity``."""
    return any(entity == node["e"] for node in arr)
228
+
229
def gen_resolved_entity(self,results,server_index,pivot_index,run_index,cross_prediction_count,servers_arr):
    """Build the final NER entry for position ``run_index``.

    Args:
        results: per-server result dicts (``ner``,
            ``orig_cs_prediction_details``, ``orig_ci_prediction_details``).
        server_index: index of the server picked for this span.
        pivot_index: string position of the first term of the span; the
            CS/CI prediction details are always read at this position.
        run_index: string position of the term being emitted.
        cross_prediction_count: 1 or -1 means a single server's prediction
            is emitted; any other value means both servers' tags are joined.
        servers_arr: per-server descriptors; ``precedence`` is consulted.

    Returns:
        A copy of the picked server's ``ner`` entry whose ``"e"`` is either
        a single flipped tag or two flipped tags joined by ``"/"``.
    """
    if (cross_prediction_count == 1 or cross_prediction_count == -1):
        #This is the case where we are emitting just one server prediction. In this case, if CS and consolidated dont match, emit both
        if (pivot_index in results[server_index]["orig_cs_prediction_details"]):
            if (len(results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution']) == 0):
                #just use the ci prediction in this case. This happens only for boundary cases of a single entity in a sentence and there is no context
                orig_cs_entity = results[server_index]["orig_ci_prediction_details"][pivot_index]['cs_distribution'][0]
            else:
                orig_cs_entity = results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution'][0]
            orig_ci_entity = results[server_index]["orig_ci_prediction_details"][pivot_index]['cs_distribution'][0]
            # m1 / m1_ci: top CS and CI tags with any "[...]" suffix removed.
            m1 = orig_cs_entity["e"].split('[')[0]
            m1_ci = orig_ci_entity["e"].split('[')[0]
            is_ci_included = True if (m1_ci in servers_arr[server_index]["precedence"]) else False
            consolidated_entity = results[server_index]["ner"][pivot_index]
            # m2: consolidated tag without "[...]" suffix and without B_/I_ prefix.
            m2,dummy = prefix_strip(consolidated_entity["e"].split('[')[0])
            if (m1 != m2):
                #if we come here consolidated is not same as cs prediction. So we emit both consolidated and cs
                ret_obj = results[server_index]["ner"][run_index].copy()
                dummy,prefix = prefix_strip(ret_obj["e"])
                n1 = flip_category(orig_cs_entity)
                n1["e"] = prefix + n1["e"]
                n2 = flip_category(consolidated_entity)
                ret_obj["e"] = n2["e"] + "/" + n1["e"]
                return ret_obj
            else:
                #if we come here consolidated is same as cs prediction. So we try to either use ci or the second cs prediction if ci is out of domain
                if (m1 != m1_ci):
                    #CS and CI are not same
                    if (is_ci_included):
                        #Emit both CS and CI
                        ret_obj = results[server_index]["ner"][run_index].copy()
                        dummy,prefix = prefix_strip(ret_obj["e"])
                        n1 = flip_category(orig_cs_entity)
                        n1["e"] = prefix + n1["e"]
                        n2 = flip_category(orig_ci_entity)
                        n2["e"] = prefix + n2["e"]
                        ret_obj["e"] = n1["e"] + "/" + n2["e"]
                        return ret_obj
                    else:
                        #We come here for the case where CI is not in server list. So we pick the second cs as an option if meaningful
                        if (len(results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution']) >= 2):
                            ret_arr = self.get_predictions_above_threshold(results[server_index]["orig_cs_prediction_details"][pivot_index])
                            orig_cs_second_entity = results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution'][1]
                            m2_cs = orig_cs_second_entity["e"].split('[')[0]
                            # The membership result is overwritten with True on the next line.
                            is_cs_included = True if (m2_cs in servers_arr[server_index]["precedence"]) else False
                            is_cs_included = True #Disabling cs included check. If prediction above threshold is cross prediction, then letting it through
                            assert (m2_cs != m1)
                            if (is_cs_included and self.check_if_entity_in_arr(m2_cs,ret_arr)):
                                ret_obj = results[server_index]["ner"][run_index].copy()
                                dummy,prefix = prefix_strip(ret_obj["e"])
                                n1 = flip_category(orig_cs_second_entity)
                                n1["e"] = prefix + n1["e"]
                                n2 = flip_category(orig_cs_entity)
                                n2["e"] = prefix + n2["e"]
                                ret_obj["e"] = n2["e"] + "/" + n1["e"]
                                return ret_obj
                            else:
                                return flip_category(results[server_index]["ner"][run_index])
                        else:
                            return flip_category(results[server_index]["ner"][run_index])
                else:
                    #here cs and ci are same. So use two cs predictions if meaningful
                    # NOTE(review): this branch repeats the second-CS logic of
                    # the "CI not in server list" branch above line for line.
                    if (len(results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution']) >= 2):
                        ret_arr = self.get_predictions_above_threshold(results[server_index]["orig_cs_prediction_details"][pivot_index])
                        orig_cs_second_entity = results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution'][1]
                        m2_cs = orig_cs_second_entity["e"].split('[')[0]
                        is_cs_included = True if (m2_cs in servers_arr[server_index]["precedence"]) else False
                        is_cs_included = True #Disabling cs included check. If prediction above threshold is cross prediction, then letting it through
                        assert (m2_cs != m1)
                        if (is_cs_included and self.check_if_entity_in_arr(m2_cs,ret_arr)):
                            ret_obj = results[server_index]["ner"][run_index].copy()
                            dummy,prefix = prefix_strip(ret_obj["e"])
                            n1 = flip_category(orig_cs_second_entity)
                            n1["e"] = prefix + n1["e"]
                            n2 = flip_category(orig_cs_entity)
                            n2["e"] = prefix + n2["e"]
                            ret_obj["e"] = n2["e"] + "/" + n1["e"]
                            return ret_obj
                        else:
                            return flip_category(results[server_index]["ner"][run_index])
                    else:
                        return flip_category(results[server_index]["ner"][run_index])
        else:
            # No CS details at the pivot position: emit the consolidated tag alone.
            return flip_category(results[server_index]["ner"][run_index])
    else:
        #Case where both servers dont match
        ret_obj = results[server_index]["ner"][run_index].copy()
        #ret_obj["e"] = results[0]["ner"][run_index]["e"] + "/" + results[1]["ner"][run_index]["e"]
        # index2 is the server that was NOT picked; the picked server's tag is
        # written first. NOTE(review): the original comment called index2 the
        # dominant, higher-confidence server, which does not match the code.
        index2 = 1 if server_index == 0 else 0
        n1 = flip_category(results[server_index]["ner"][run_index])
        n2 = flip_category(results[index2]["ner"][run_index])
        ret_obj["e"] = n1["e"] + "/" + n2["e"]
        return ret_obj
322
+
323
+
324
def confirm_same_size_responses(self,sent,results):
    """Return the number of terms usable across all server responses.

    A response without a ``"ner"`` key is logged to ``self.error_fp`` and
    makes the whole call return 0. Otherwise the smallest ``ner`` length
    among the responses is returned (0 when ``results`` is empty).
    """
    count = 0
    for i, response in enumerate(results):
        if "ner" not in response:
            print("Server",i," returned invalid response;",response)
            self.error_fp.write("Server " + str(i) + " failed for query: " + sent + "\n")
            self.error_fp.flush()
            return 0
        ner = response["ner"]
        if count == 0:
            assert(len(ner) > 0)
            count = len(ner)
        elif count != len(ner):
            print("Warning. The return sizes of both servers do not match. This must be truncated sentence, where tokenization causes different length truncations. Using min length")
            count = min(count,len(ner))
    return count
342
+
343
+
344
def get_ensembled_entities(self,sent,results,servers_arr):
    """Merge the two servers' NER results term by term.

    Args:
        sent: the input sentence (used only for failure logging).
        results: list of exactly two per-server result dicts.
        servers_arr: list of exactly two per-server descriptors.

    Returns:
        A 9-tuple of OrderedDicts keyed by one-based string position:
        ``(ensembled_ner, ensembled_conf, ensembled_ci, ensembled_cs,
        ambig_ensembled_conf, ambig_ensembled_ci, ambig_ensembled_cs,
        orig_cs_predictions, orig_ci_predictions)``. All are empty when a
        server response is unusable. ``ambig_ensembled_cs`` is never filled
        in by this method.
    """
    ensembled_ner = OrderedDict()
    orig_cs_predictions = OrderedDict()
    orig_ci_predictions = OrderedDict()
    ensembled_conf = OrderedDict()
    ambig_ensembled_conf = OrderedDict()
    ensembled_ci = OrderedDict()
    ensembled_cs = OrderedDict()
    ambig_ensembled_ci = OrderedDict()
    ambig_ensembled_cs = OrderedDict()
    print("Ensemble candidates")
    terms_count = self.confirm_same_size_responses(sent,results)
    if (terms_count == 0):
        return ensembled_ner,ensembled_conf,ensembled_ci,ensembled_cs,ambig_ensembled_conf,ambig_ensembled_ci,ambig_ensembled_cs,orig_cs_predictions,orig_ci_predictions
    assert(len(servers_arr) == len(results))
    term_index = 0
    while (term_index < terms_count):
        assert(len(servers_arr) == 2) #TBD. Currently assumes two servers in prototype to see if this approach works. To be extended to multiple servers
        server_index,span_count,cross_prediction_count = self.get_conflict_resolved_entity(results,term_index,terms_count,servers_arr)
        # The span's first position; prediction details are resolved against it.
        pivot_index = str(term_index + 1)
        for span_index in range(span_count):
            run_index = str(term_index + 1 + span_index)
            ensembled_ner[run_index] = self.gen_resolved_entity(results,server_index,pivot_index,run_index,cross_prediction_count,servers_arr)
            if (run_index in results[server_index]["entity_distribution"]):
                ensembled_conf[run_index] = results[server_index]["entity_distribution"][run_index]
                ensembled_conf[run_index]["e"] = strip_prefixes(ensembled_ner[run_index]["e"]) #this is to make sure the same tag can be taken from NER result or this structure.
                #When both server responses are required, just return the details of first server for now
                ensembled_ci[run_index] = results[server_index]["ci_prediction_details"][run_index]
                ensembled_cs[run_index] = results[server_index]["cs_prediction_details"][run_index]
                orig_cs_predictions[run_index] = results[server_index]["orig_cs_prediction_details"][run_index]
                orig_ci_predictions[run_index] = results[server_index]["orig_ci_prediction_details"][run_index]

            if (cross_prediction_count == 0 or cross_prediction_count == 2): #This is an ambiguous prediction. Send both server responses
                # The other of the two servers. This used to evaluate to 1
                # unconditionally, so when server 1 was picked the "ambiguous"
                # details merely repeated the picked server.
                second_server = 1 - server_index
                if (run_index in results[second_server]["entity_distribution"]): #It may not be present if the B/I tags are out of sync from servers.
                    ambig_ensembled_conf[run_index] = results[second_server]["entity_distribution"][run_index]
                    ambig_ensembled_conf[run_index]["e"] = ensembled_ner[run_index]["e"] #this is to make sure the same tag can be taken from NER result or this structure.
                    ambig_ensembled_ci[run_index] = results[second_server]["ci_prediction_details"][run_index]
            if (ensembled_ner[run_index]["e"] != "O"):
                self.inferred_entities_log_fp.write(results[0]["ner"][run_index]["term"] + " " + ensembled_ner[run_index]["e"] + "\n")
        term_index += span_count
    self.inferred_entities_log_fp.flush()
    return ensembled_ner,ensembled_conf,ensembled_ci,ensembled_cs,ambig_ensembled_conf,ambig_ensembled_ci,ambig_ensembled_cs,orig_cs_predictions,orig_ci_predictions
388
+
389
+
390
+
391
def ensemble_processing(self,sent,results):
    """Ensemble ``results`` and package the pieces under their output keys.

    Calls ``get_ensembled_entities`` with the module-level ``actions_arr``
    and returns an OrderedDict holding its nine outputs, in order.
    """
    output_keys = (
        "ensembled_ner",
        "ensembled_prediction_details",
        "ci_prediction_details",
        "cs_prediction_details",
        "ambig_prediction_details_conf",
        "ambig_prediction_details_ci",
        "ambig_prediction_details_cs",
        "orig_cs_prediction_details",
        "orig_ci_prediction_details",
    )
    # get_ensembled_entities returns its values in exactly this key order.
    pieces = self.get_ensembled_entities(sent,results,actions_arr)
    final_ner = OrderedDict()
    for key, piece in zip(output_keys,pieces):
        final_ner[key] = piece
    return final_ner
406
+
407
+
408
+
409
+
410
class myThread (threading.Thread):
    """Worker thread that queries one NER server and stores its JSON reply.

    After ``run`` completes, ``results`` is a dict parsed from the response
    (empty if the request or the parse failed) with ``"server"`` set to
    ``desc``.
    """

    def __init__(self, url,param,desc):
        """Remember the endpoint ``url``, the query ``param`` and the server label ``desc``."""
        threading.Thread.__init__(self)
        self.url = url
        self.param = param
        self.desc = desc
        self.results = {}

    def run(self):
        """Issue the GET request and populate ``self.results``."""
        print ("Starting " + self.url + self.param)
        escaped_url = self.url + self.param.replace("#","-") #TBD. This is a nasty hack for client side handling of #. To be fixed. For some reason, even replacing with parse.quote or just with %23 does not help. The fragment after # is not sent to server. Works just fine in wget with %23
        print("ESCAPED:",escaped_url)
        try:
            # The request lives inside the try so a connection failure yields
            # an empty result instead of killing the thread.
            out = requests.get(escaped_url)
            parsed = json.loads(out.text,object_pairs_hook=OrderedDict)
            if not isinstance(parsed,dict):
                # Valid JSON that is not an object cannot carry the expected keys.
                raise ValueError("response is not a JSON object")
            self.results = parsed
        except (requests.exceptions.RequestException, ValueError):
            # json.JSONDecodeError is a ValueError subclass.
            print("Empty response from server for input:",self.param)
            self.results = OrderedDict()
        self.results["server"] = self.desc
        print ("Exiting " + self.url + self.param)
429
+
430
+
431
+
432
+ # Create new threads
433
def create_workers(inp_dict,inp):
    """Return one unstarted ``myThread`` per entry of ``inp_dict``.

    Entries are read by index ``0 .. len(inp_dict) - 1``; each supplies the
    ``url`` and ``desc`` of a worker, and all workers share the query ``inp``.
    """
    return [
        myThread(inp_dict[idx]["url"],inp,inp_dict[idx]["desc"])
        for idx in range(len(inp_dict))
    ]
438
+
439
def start_workers(threads_arr):
    """Start every thread in ``threads_arr``, in order."""
    for worker in threads_arr:
        worker.start()
442
+
443
def wait_for_completion(threads_arr):
    """Block until every thread in ``threads_arr`` has finished."""
    for worker in threads_arr:
        worker.join()
446
+
447
def get_results(threads_arr):
    """Return the ``results`` attribute of each worker, in worker order."""
    return [worker.results for worker in threads_arr]
452
+
453
+
454
+
455
def prefix_strip(term):
    """Split a leading ``B_`` / ``I_`` tag off ``term``.

    Returns ``(term_without_prefix, prefix)``; ``prefix`` is ``""`` when the
    term carries neither tag.
    """
    if term.startswith(("B_", "I_")):
        return term[2:], term[:2]
    return term, ""
461
+
462
def strip_prefixes(term):
    """Remove ``B_`` / ``I_`` tags from a single entity or an ``A/B`` pair.

    A term with one ``/`` is treated as two entities and re-joined with
    ``/``; a term without ``/`` is treated as one entity. Any other shape
    trips the assertion.
    """
    parts = term.split('/')
    if len(parts) == 2:
        return '/'.join(prefix_strip(part)[0] for part in parts)
    assert(len(parts) == 1)
    return prefix_strip(parts[0])[0]
472
+
473
+
474
+ #This hack is simply done for downstream API used for UI displays the entity instead of the class. Details has all additional info
475
def flip_category(obj):
    """Return a shallow copy of ``obj`` with ``X[Y]`` in ``"e"`` swapped to ``Y[X]``.

    A leading ``B_`` / ``I_`` tag stays at the front of the flipped value.
    Values without a ``[`` are returned unchanged (still as a copy).
    """
    flipped = obj.copy()
    pieces = obj["e"].split("[")
    if len(pieces) < 2:
        return flipped
    outer = pieces[0]
    inner = pieces[1].rstrip("]")
    if outer.startswith(("B_", "I_")):
        tag, outer = outer[:2], outer[2:]
    else:
        tag = ""
    flipped["e"] = tag + inner + "[" + outer + "]"
    return flipped
486
+
487
+
488
def extract_main_entity(results,server_index,pos_index):
    """Return the entity class at ``pos_index`` for one server, without ``[subtype]`` or ``B_``/``I_`` tag."""
    tagged = results[server_index]["ner"][pos_index]["e"]
    return prefix_strip(tagged.split('[')[0])[0]
492
+
493
+
494
def get_span_info(results,server_index,term_index,terms_count):
    """Return how many consecutive terms the entity starting at ``term_index`` spans.

    ``term_index`` is 0-based while the ``"ner"`` mapping is keyed by
    1-based position strings. A ``B_`` entity extends over the following
    terms until an ``"O"`` tag or ``terms_count`` is reached; anything else
    (including an unexpected ``I_`` start) counts as a span of 1.
    """
    ner = results[server_index]["ner"]
    entity = ner[str(term_index + 1)]["e"]
    if entity.startswith("I_"):
        print("Skipping an I tag for server:",server_index,". This has to be done because of mismatched span because of model specific casing normalization that changes POS tagging. This happens only for sentencees user does not explicirly tag with ':__entity__'")
        return 1
    span_count = 1
    if entity.startswith("B_"):
        for following in range(term_index + 1, terms_count):
            if ner[str(following + 1)]["e"] == "O":
                break
            span_count += 1
    return span_count
512
+
513
def is_included_in_server_entities(predictions,s_arr,check_first_only):
    """Check that predicted entity classes belong to a server's entity sets.

    A class is accepted when it appears in ``s_arr["precedence"]`` or
    ``s_arr["common"]`` (being in "common" is not treated as a cross over).
    With ``check_first_only`` only the top prediction is examined.
    Returns ``True`` for an empty ``predictions``.
    """
    for prediction in predictions:
        label = prediction['e'].split('[')[0]
        accepted = (label in s_arr["precedence"]) or (label in s_arr["common"])
        if not accepted:
            return False
        if check_first_only:
            # Just check the top prediction for inclusion in the new semantics
            break
    return True
521
+
522
def strict_is_included_in_server_entities(predictions,s_arr,check_first_only):
    """Strict variant: accept only classes listed in ``s_arr["precedence"]``.

    With ``check_first_only`` only the top prediction is examined.
    Returns ``True`` for an empty ``predictions``.
    """
    for prediction in predictions:
        label = prediction['e'].split('[')[0]
        if label not in s_arr["precedence"]:
            return False
        if check_first_only:
            # Just check the top prediction for inclusion in the new semantics
            break
    return True
530
+
531
+
532
+
533
+ if __name__ == '__main__':
534
+ parser = argparse.ArgumentParser(description='main NER for a single model ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
535
+ parser.add_argument('-input', action="store", dest="input",default=DEFAULT_TEST_BATCH_FILE,help='Input file for batch run option')
536
+ parser.add_argument('-config', action="store", dest="config", default=DEFAULT_CONFIG,help='config file path')
537
+ parser.add_argument('-output', action="store", dest="output",default=NER_OUTPUT_FILE,help='Output file for batch run option')
538
+ parser.add_argument('-option', action="store", dest="option",default="canned",help='Valid options are canned,batch,interactive. canned - test few canned sentences used in medium artice. batch - tag sentences in input file. Entities to be tagged are determing used POS tagging to find noun phrases.interactive - input one sentence at a time')
539
+ results = parser.parse_args()
540
+ config_file = results.config
541
+
app.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import streamlit as st
3
+ import torch
4
+ import string
5
+ from annotated_text import annotated_text
6
+
7
+ from flair.data import Sentence
8
+ from flair.models import SequenceTagger
9
+ from transformers import BertTokenizer, BertForMaskedLM
10
+ import BatchInference as bd
11
+ import batched_main_NER as ner
12
+ import aggregate_server_json as aggr
13
+ import json
14
+
15
+
16
+ DEFAULT_TOP_K = 20
17
+ SPECIFIC_TAG=":__entity__"
18
+
19
+
20
+
21
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def POS_get_model(model_name):
    """Load the sequence tagger named ``model_name`` (cached via ``st.cache``)."""
    return SequenceTagger.load(model_name)
25
+
26
def getPos(s: Sentence):
    """Collect token texts and their first label value per annotation layer.

    Returns two parallel lists ``(texts, labels)``; a token contributes one
    entry for each of its annotation layers.
    """
    pairs = [
        (token.text, token.get_labels(layer)[0].value)
        for token in s.tokens
        for layer in token.annotation_layers.keys()
    ]
    texts = [text for text, _ in pairs]
    labels = [label for _, label in pairs]
    return texts, labels
34
+
35
def getDictFromPOS(texts, labels):
    """Pack (text, label) pairs into 5-field rows with ``"dummy"`` placeholders.

    Each row is ``["dummy", text, label, "dummy", "dummy"]``.
    """
    rows = []
    for text, label in zip(texts, labels):
        rows.append(["dummy", text, label, "dummy", "dummy"])
    return rows
37
+
38
def decode(tokenizer, pred_idx, top_clean):
    """Turn predicted token ids into a newline-separated list of words.

    Each id in ``pred_idx`` is decoded with ``tokenizer.decode`` and stripped
    of whitespace. Tokens that are empty, consist only of punctuation, or are
    the ``[PAD]`` token are dropped; ``##`` word-piece markers are removed
    from the rest. At most ``top_clean`` surviving tokens are returned.
    """
    # Use a character set rather than substring matching against one long
    # string: the latter also discarded real tokens such as "A", "D" or "PAD".
    punctuation = set(string.punctuation)
    tokens = []
    for w in pred_idx:
        token = ''.join(tokenizer.decode(w).split())
        if token == '[PAD]' or all(ch in punctuation for ch in token):
            continue
        tokens.append(token.replace('##', ''))
    return '\n'.join(tokens[:top_clean])
46
+
47
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Encode a sentence containing ``<mask>`` and locate the mask position.

    ``<mask>`` is replaced by the tokenizer's own mask token. Returns
    ``(input_ids, mask_idx)`` where ``input_ids`` is a 1 x N tensor and
    ``mask_idx`` is the column of the first mask token.
    """
    mask = tokenizer.mask_token
    text_sentence = text_sentence.replace('<mask>', mask)
    # if <mask> is the last token, append a "." so that models dont predict punctuation.
    if text_sentence.split()[-1] == mask:
        text_sentence = text_sentence + ' .'

    token_ids = tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)
    input_ids = torch.tensor([token_ids])
    mask_columns = torch.where(input_ids == tokenizer.mask_token_id)[1]
    return input_ids, mask_columns.tolist()[0]
56
+
57
def get_all_predictions(text_sentence, top_clean=5, top_k=None):
    """Return the top masked-word predictions for ``text_sentence``.

    Args:
        text_sentence: sentence containing a ``<mask>`` placeholder.
        top_clean: maximum number of cleaned tokens to return.
        top_k: number of raw candidates taken from the model before
            cleaning; defaults to ``DEFAULT_TOP_K``. (Previously this was an
            undefined global, which raised ``NameError``.)

    Returns:
        ``{'bert': <newline-separated predictions>}``.
    """
    # NOTE(review): bert_tokenizer and bert_model are expected to exist as
    # module-level globals; they are not defined in the visible part of this
    # file - confirm before using this function.
    if top_k is None:
        top_k = DEFAULT_TOP_K
    # ========================= BERT =================================
    input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
    with torch.no_grad():
        predict = bert_model(input_ids)[0]
    bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
    return {'bert': bert}
64
+
65
def get_bert_prediction(input_text,top_k):
    """Predict the word following ``input_text`` by appending a ``<mask>``.

    Returns the dictionary produced by ``get_all_predictions``, or ``None``
    if prediction fails. Failures are reported on stdout instead of being
    silently discarded.
    """
    try:
        return get_all_predictions(input_text + ' <mask>', top_clean=int(top_k))
    except Exception as error:
        # Best-effort helper: keep returning None on failure, but say why.
        print("BERT prediction failed:", error)
        return None
72
+
73
+
74
def load_pos_model():
    """Load the ``flair/pos-english`` POS tagger through the cached loader."""
    return POS_get_model("flair/pos-english")
77
+
78
+
79
+
80
+
81
def init_session_states():
    """Seed ``st.session_state`` with defaults for keys that are not set yet."""
    defaults = (
        ('top_k', 20),
        ('pos_model', None),
        ('bio_model', None),
        ('phi_model', None),
        ('ner_bio', None),
        ('ner_phi', None),
        ('aggr', None),
    )
    for key, initial in defaults:
        if key not in st.session_state:
            st.session_state[key] = initial
96
+
97
+
98
+
99
def get_pos_arr(input_text,display_area):
    """POS-tag ``input_text`` and return the tags as 5-field rows.

    The POS model is loaded into ``st.session_state['pos_model']`` on first
    use; ``display_area`` shows a status message while it loads.
    """
    if st.session_state['pos_model'] is None:
        display_area.text("Loading model 3 of 3.Loading POS model...")
        st.session_state['pos_model'] = load_pos_model()
    sentence = Sentence(input_text)
    st.session_state['pos_model'].predict(sentence)
    return getDictFromPOS(*getPos(sentence))
108
+
109
def perform_inference(text,display_area):
    """Run both models on ``text`` and return the aggregated NER result.

    Models and helper objects are created lazily and kept in
    ``st.session_state``; ``display_area`` receives progress messages.
    POS tags are computed only when ``text`` has no ``:__entity__`` marker.
    """
    state = st.session_state

    def ensure_loaded(key, status, build):
        # Create and cache a component the first time it is needed.
        if state[key] is None:
            display_area.text(status)
            state[key] = build()

    ensure_loaded(
        'bio_model', "Loading model 1 of 3. Bio model...",
        lambda: bd.BatchInference("bio/desc_a100_config.json",'ajitrajasekharan/biomedical',False,False,DEFAULT_TOP_K,True,True, "bio/","bio/a100_labels.txt",False))
    ensure_loaded(
        'phi_model', "Loading model 2 of 3. PHI model...",
        lambda: bd.BatchInference("bbc/desc_bbc_config.json",'bert-base-cased',False,False,DEFAULT_TOP_K,True,True, "bbc/","bbc/bbc_labels.txt",False))

    #Load POS model if needed and gets POS tags
    pos_arr = None if SPECIFIC_TAG in text else get_pos_arr(text,display_area)

    ensure_loaded('ner_bio', "Initializing BIO module...",
                  lambda: ner.UnsupNER("bio/ner_a100_config.json"))
    ensure_loaded('ner_phi', "Initializing PHI module...",
                  lambda: ner.UnsupNER("bbc/ner_bbc_config.json"))
    ensure_loaded('aggr', "Initializing Aggregation modeule...",
                  lambda: aggr.AggregateNER("./ensemble_config.json"))

    display_area.text("Getting results from BIO model...")
    bio_descs = state['bio_model'].get_descriptors(text,pos_arr)
    display_area.text("Getting results from PHI model...")
    phi_descs = state['phi_model'].get_descriptors(text,pos_arr)
    display_area.text("Aggregating BIO & PHI results...")
    bio_ner = state['ner_bio'].tag_sentence_service(text,bio_descs)
    phi_ner = state['ner_phi'].tag_sentence_service(text,phi_descs)

    return state['aggr'].fetch_all(text,[json.loads(bio_ner),json.loads(phi_ner)])
151
+
152
+
153
+ sent_arr = [
154
+ "Lou Gehrig who works for XCorp and lives in New York suffers from Parkinson's ",
155
+ "Parkinson who works for XCorp and lives in New York suffers from Lou Gehrig's",
156
+ "lou gehrig was diagnosed with Parkinson's ",
157
+ "A eGFR below 60 indicates chronic kidney disease",
158
+ "Overexpression of EGFR occurs across a wide range of different cancers",
159
+ "Stanford called",
160
+ "He was diagnosed with non small cell lung cancer",
161
+ "I met my girl friends at the pub ",
162
+ "I met my New York friends at the pub",
163
+ "I met my XCorp friends at the pub",
164
+ "I met my two friends at the pub",
165
+ "Bio-Techne's genomic tools include advanced tissue-based in-situ hybridization assays sold under the ACD brand as well as a portfolio of assays for prostate cancer diagnosis ",
166
+ "There are no treatment options specifically indicated for ACD and physicians must utilize agents approved for other dermatology conditions", "As ACD has been implicated in apoptosis-resistant glioblastoma (GBM), there is a high medical need for identifying novel ACD-inducing drugs ",
167
+ "Located in the heart of Dublin , in the family home of acclaimed writer Oscar Wilde , ACD provides the perfect backdrop to inspire Irish (and Irish-at-heart) students to excel in business and the arts",
168
+ "Patients treated with anticancer chemotherapy drugs ( ACD ) are vulnerable to infectious diseases due to immunosuppression and to the direct impact of ACD on their intestinal microbiota ",
169
+ "In the LASOR trial , increasing daily imatinib dose from 400 to 600mg induced MMR at 12 and 24 months in 25% and 36% of the patients, respectively, who had suboptimal cytogenetic responses ",
170
+ "The sky turned dark in advance of the storm that was coming from the east ",
171
+ "She loves to watch Sunday afternoon football with her family ",
172
+ "Paul Erdos died at 83 "
173
+ ]
174
+
175
+
176
+ sent_arr_masked = [
177
+ "Lou Gehrig:__entity__ who works for XCorp:__entity__ and lives in New:__entity__ York:__entity__ suffers from Parkinson's:__entity__ ",
178
+ "Parkinson:__entity__ who works for XCorp:__entity__ and lives in New:__entity__ York:__entity__ suffers from Lou Gehrig's:__entity__",
179
+ "lou:__entity__ gehrig:__entity__ was diagnosed with Parkinson's:__entity__ ",
180
+ "A eGFR:__entity__ below 60 indicates chronic kidney disease",
181
+ "Overexpression of EGFR:__entity__ occurs across a wide range of different cancers",
182
+ "Stanford:__entity__ called",
183
+ "He was diagnosed with non:__entity__ small:__entity__ cell:__entity__ lung:__entity__ cancer:__entity__",
184
+ "I met my girl:__entity__ friends at the pub ",
185
+ "I met my New:__entity__ York:__entity__ friends at the pub",
186
+ "I met my XCorp:__entity__ friends at the pub",
187
+ "I met my two:__entity__ friends at the pub",
188
+ "Bio-Techne's genomic tools include advanced tissue-based in-situ hybridization assays sold under the ACD:__entity__ brand as well as a portfolio of assays for prostate cancer diagnosis ",
189
+ "There are no treatment options specifically indicated for ACD:__entity__ and physicians must utilize agents approved for other dermatology conditions",
190
+ "As ACD:__entity__ has been implicated in apoptosis-resistant glioblastoma (GBM), there is a high medical need for identifying novel ACD-inducing drugs ",
191
+ "Located in the heart of Dublin , in the family home of acclaimed writer Oscar Wilde , ACD:__entity__ provides the perfect backdrop to inspire Irish (and Irish-at-heart) students to excel in business and the arts",
192
+ "Patients treated with anticancer chemotherapy drugs ( ACD:__entity__ ) are vulnerable to infectious diseases due to immunosuppression and to the direct impact of ACD on their intestinal microbiota ",
193
+ "In the LASOR:__entity__ trial:__entity__ , increasing daily imatinib dose from 400 to 600mg induced MMR at 12 and 24 months in 25% and 36% of the patients, respectively, who had suboptimal cytogenetic responses ",
194
+ "The sky turned dark:__entity__ in advance of the storm that was coming from the east ",
195
+ "She loves to watch Sunday afternoon football:__entity__ with her family ",
196
+ "Paul:__entity__ Erdos:__entity__ died at 83:__entity__ "
197
+ ]
198
+
199
def init_selectbox():
    """Render the sample-sentence pull-down and return the selected sentence."""
    prompt = 'Choose any of the sentences in pull-down below'
    return st.selectbox(prompt, sent_arr, key='my_choice')
203
+
204
+
205
def on_text_change():
    """Callback for the ``my_text`` widget: run inference on its current value.

    ``perform_inference`` needs a placeholder for its progress messages, so a
    fresh ``st.empty()`` is supplied (the previous one-argument call raised
    ``TypeError``).
    """
    text = st.session_state.my_text
    print("in callback: " + text)
    perform_inference(text, st.empty())
209
+
210
def main():
    """Render the Streamlit page: header, input form, results and footer links.

    On submit, the typed sentence (or, if empty, the masked version of the
    selected sample sentence) is run through ``perform_inference`` and the
    result is shown as JSON together with the elapsed time. Any exception is
    reported on the page via ``st.exception``.
    """
    try:

        init_session_states()

        st.markdown("<h3 style='text-align: center;'>NER using pretrained models with <a href='https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html'>no fine tuning</a></h3>", unsafe_allow_html=True)
        #st.markdown("""
        #<h3 style="font-size:16px; color: #ff0000; text-align: center"><b>App under construction... (not in working condition yet)</b></h3>
        #""", unsafe_allow_html=True)

        st.markdown("""
<p style="text-align:center;"><img src="https://ajitrajasekharan.github.io/images/1.png" width="700"></p>
<br/>
<br/>
""", unsafe_allow_html=True)

        st.write("This app uses 3 models. Two Pretrained Bert models (**no fine tuning**) and a POS tagger")

        with st.form('my_form'):
            selected_sentence = init_selectbox()
            text_input = st.text_area(label='Type any sentence below',value="")
            submit_button = st.form_submit_button('Submit')
            input_status_area = st.empty()
            display_area = st.empty()
            if submit_button:
                start = time.time()
                if (len(text_input) == 0):
                    # Nothing typed: fall back to the pre-tagged variant of
                    # the sentence chosen in the pull-down.
                    text_input = sent_arr_masked[sent_arr.index(selected_sentence)]
                input_status_area.text("Input sentence: " + text_input)
                results = perform_inference(text_input,display_area)
                display_area.empty()
                with display_area.container():
                    st.text(f"prediction took {time.time() - start:.2f}s")
                    st.json(results)

        #input_text = st.text_area(
        #    label="Type any sentence",
        #    on_change=on_text_change,key='my_text'
        #    )

        # The closing line of this string used to start with a stray "#",
        # which was part of the literal and therefore rendered on the page.
        st.markdown("""
<small style="font-size:16px; color: #7f7f7f; text-align: left"><br/><br/>Models used: <br/>(1) <a href='https://huggingface.co/ajitrajasekharan/biomedical' target='_blank'>Biomedical model</a> pretrained on Pubmed,Clinical trials and BookCorpus subset.<br/>(2) Bert-base-cased (for PHI entities - Person/location/organization etc.)<br/>(3) Flair POS tagger</small>
""", unsafe_allow_html=True)
        st.markdown("""
<h3 style="font-size:16px; color: #9f9f9f; text-align: center"><b> <a href='https://huggingface.co/spaces/ajitrajasekharan/Qualitative-pretrained-model-evaluation' target='_blank'>App link to examine pretrained models</a> used to perform NER without fine tuning</b></h3>
""", unsafe_allow_html=True)
        st.markdown("""
<h3 style="font-size:16px; color: #9f9f9f; text-align: center">Github <a href='http://github.com/ajitrajasekharan/unsupervised_NER' target='_blank'>link to same working code </a>(without UI) as separate microservices</h3>
""", unsafe_allow_html=True)

    except Exception as e:
        print("Some error occurred in main")
        st.exception(e)
270
+ if __name__ == "__main__":
271
+ main()
batched_main_NER.py ADDED
@@ -0,0 +1,905 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import config_utils as cf
3
+ import requests
4
+ import sys
5
+ import urllib.parse
6
+ import numpy as np
7
+ from collections import OrderedDict
8
+ import argparse
9
+ from common import *
10
+ import json
11
+
12
+ #WORD_POS = 1
13
+ #TAG_POS = 2
14
+ #MASK_TAG = "__entity__"
15
+ DEFAULT_CONFIG = "./config.json"
16
+ DISPATCH_MASK_TAG = "entity"
17
+ DESC_HEAD = "PIVOT_DESCRIPTORS:"
18
+ #TYPE2_AMB = "AMB2-"
19
+ TYPE2_AMB = ""
20
+ DUMMY_DESCS=10
21
+ DEFAULT_ENTITY_MAP = "entity_types_consolidated.txt"
22
+
23
+ #RESET_POS_TAG='RESET'
24
+ SPECIFIC_TAG=":__entity__"
25
+
26
+
27
def softmax(x):
    """Compute softmax values for each sets of scores in x.

    Normalisation is along axis 0. The per-column maximum is subtracted
    before exponentiating so that large scores do not overflow to ``inf``
    (which previously produced ``nan``); the result is mathematically
    unchanged.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(x)
    exps = np.exp(x - np.max(x, axis=0))
    return exps / np.sum(exps, axis=0)
30
+
31
+
32
+ #noun_tags = ['NFP','JJ','NN','FW','NNS','NNPS','JJS','JJR','NNP','POS','CD']
33
+ #cap_tags = ['NFP','JJ','NN','FW','NNS','NNPS','JJS','JJR','NNP','PRP']
34
+
35
def read_common_descs(file_name):
    """Read one descriptor per line from ``file_name`` into a lookup dict.

    Returns a dict mapping each whitespace-stripped line to ``1`` and prints
    how many distinct entries were read.
    """
    with open(file_name) as fp:
        common_descs = {line.strip(): 1 for line in fp}
    print("Common descs for filtering read:",len(common_descs))
    return common_descs
42
+
43
def read_entity_map(file_name):
    """Read the entity consolidation map from ``file_name``.

    Each non-blank line is either ``ENTITY`` (maps to itself) or
    ``ENTITY A/B/C`` (``ENTITY`` maps to itself and each of ``A``, ``B``,
    ``C`` maps to ``ENTITY``). Blank lines are ignored (they used to trip
    the two-field assertion). Duplicate definitions and lines with more than
    two fields still fail an assertion.

    Returns:
        dict mapping every entity name to its consolidated entity.
    """
    emap = {}
    with open(file_name) as fp:
        for line in fp:
            entities = line.split()
            if not entities:
                continue
            head = entities[0]
            if len(entities) == 1:
                assert(head not in emap)
                emap[head] = head
            else:
                assert(len(entities) == 2)
                if head not in emap:
                    emap[head] = head
                for entity in entities[1].split('/'):
                    assert(entity not in emap)
                    emap[entity] = head
    print("Entity map:",len(emap))
    return emap
62
+
63
+ class UnsupNER:
64
def __init__(self,config_file):
    """Set up the NER handler from the JSON config at ``config_file``.

    Reads the server URLs, the common-descriptor and entity-map files, the
    ``SUPPRESS_UNTAGGED`` flag, and opens three append-mode log files under
    ``BASE_PATH`` (default ``"./"``). The config is parsed once and reused
    for every setting.
    """
    print("NER service handler started")
    # presumably cf.read_config returns the parsed config mapping; it is
    # read once here instead of once per setting.
    config = cf.read_config(config_file)
    base_path = config["BASE_PATH"] if ("BASE_PATH" in config) else "./"
    self.pos_server_url = config["POS_SERVER_URL"]
    self.desc_server_url = config["DESC_SERVER_URL"]
    self.entity_server_url = config["ENTITY_SERVER_URL"]
    self.common_descs = read_common_descs(config["COMMON_DESCS_FILE"])
    self.entity_map = read_entity_map(config["EMAP_FILE"])
    # Log files stay open for the lifetime of the handler.
    self.rfp = open(base_path + "log_results.txt","a")
    self.dfp = open(base_path + "log_debug.txt","a")
    self.algo_ci_tag_fp = open(base_path + "algorthimic_ci_tags.txt","a")
    print(self.pos_server_url)
    print(self.desc_server_url)
    print(self.entity_server_url)
    np.set_printoptions(suppress=True) #this suppresses exponential representation when np is used to round
    # Anything other than "1" disables suppression (used in full debug text mode).
    self.suppress_untagged = (config["SUPPRESS_UNTAGGED"] == "1")
83
+
84
+
85
+ #This is bad hack for prototyping - parsing from text output as opposed to json
86
def extract_POS(self,text):
    """Parse tab-separated POS rows from a plain-text tagger response.

    Everything up to the first empty line is skipped; from there on, each
    line with exactly five tab-separated fields is returned as a list of
    fields. Returns ``[]`` when there is no empty line.
    """
    lines = text.split('\n')
    try:
        first_blank = lines.index('')
    except ValueError:
        first_blank = len(lines)
    rows = (line.split('\t') for line in lines[first_blank:])
    return [fields for fields in rows if len(fields) == 5]
104
+
105
def normalize_casing(self,sent):
    """Lower-case everything but the first character of each word.

    Words are split on whitespace and re-joined with single spaces.
    """
    return ' '.join(word[:1] + word[1:].lower() for word in sent.split())
115
+
116
+ #Full sentence tag call also generates json output.
117
def tag_sentence_service(self,text,desc_obj):
    """Tag ``text`` using ``desc_obj`` and return the JSON-output result of ``tag_sentence``.

    Uses the handler's own result and debug log files.
    """
    return self.tag_sentence(text,self.rfp,self.dfp,True,desc_obj)
120
+
121
def dictify_ner_response(self,ner_str):
    """Convert a line-oriented NER string into a position-keyed dictionary.

    Each line is ``term tag`` or just ``tag`` (term recorded as ``"empty"``).
    Returns ``(ret_dict, ref_indices_arr)``: an ``OrderedDict`` keyed by
    1-based position with ``{"term", "e"}`` values, and the positions whose
    tag starts with ``B_``. For a tag-only line seen once more than three
    entries exist, the sentence and the derived label are appended to
    ``self.algo_ci_tag_fp`` for later human verification.
    """
    ret_dict = OrderedDict()
    ref_indices_arr = []
    position = 1
    for row in ner_str.split('\n'):
        fields = row.split()
        field_count = len(fields)
        if field_count not in (1, 2):
            assert(field_count == 0) #If not empty something is not right
            continue
        if field_count == 2:
            term, label = fields
        else:
            term, label = "empty", fields[0]
        ret_dict[position] = {"term":term,"e":label}
        if label.startswith("B_"):
            ref_indices_arr.append(position)
        position += 1
        if field_count == 1 and len(ret_dict) > 3: #algorithmic harvesting of CI labels for human verification and adding to bootstrap list
            self.algo_ci_tag_fp.write("SENT:" + ner_str.replace('\n',' ') + "\n")
            out = label.replace('[',' ').replace(']','').split()[-1]
            out = '_'.join(out.split('_')[1:]) if out.startswith("B_") else out
            print(out)
            self.algo_ci_tag_fp.write(ret_dict[position-2]["term"] + " " + out + "\n")
            self.algo_ci_tag_fp.flush()
    return ret_dict,ref_indices_arr
148
+
149
def blank_entity_sentence(self,sent,dfp):
    """Return ``True`` when ``sent`` ends with a bare `` :__entity__`` marker line.

    When it does, a notice that CI pooling is skipped is printed and written
    to the debug file ``dfp``.
    """
    is_blank = sent.endswith(" :__entity__\n")
    if is_blank:
        print("\n\n**************** Skipping CI prediction in pooling for sent:",sent)
        dfp.write("\n\n**************** Skipping CI prediction in pooling for sent:" + sent + "\n")
    return is_blank
155
+
156
def pool_confidences(self,ci_entities,ci_confidences,ci_subtypes,cs_entities,cs_confidences,cs_subtypes,debug_str_arr,sent,dfp):
    """Pool context-independent (CI) and context-sensitive (CS) predictions.

    Confidences are summed per entity class (the part of the label before
    ``[``), sorted in descending order and converted to a distribution with
    ``convert_positive_nums_to_dist``; subtype weights are pooled and
    normalised the same way per class. CI contributions are left out for
    blank-entity sentences. A summary line is appended to ``debug_str_arr``
    and printed.

    Returns:
        OrderedDict mapping class -> ``{"e", "confidence", "stypes"}``.
    """
    assert(len(cs_entities) == len(cs_confidences))
    assert(len(cs_subtypes) == len(cs_entities))
    assert(len(ci_entities) == len(ci_confidences))
    assert(len(ci_subtypes) == len(ci_entities))

    #Do not pool CI confidences of the sentences of the form " is a entity". These sentences are sent for purely algo harvesting of CS terms. CI predictions will add noise.
    is_blank_statement = self.blank_entity_sentence(sent,dfp)

    #Pool entity classes across CI and CS
    main_classes = {}
    if not is_blank_statement:
        for label, conf in zip(ci_entities,ci_confidences):
            main_classes[label.split('[')[0]] = float(conf)
    for label, conf in zip(cs_entities,cs_confidences):
        base = label.split('[')[0]
        if base in main_classes:
            main_classes[base] += float(conf)
        else:
            main_classes[base] = float(conf)
    sorted_classes = OrderedDict(sorted(main_classes.items(), key=lambda kv: kv[1], reverse=True))
    main_dist = self.convert_positive_nums_to_dist(sorted_classes)
    main_classes_arr = list(sorted_classes.keys())

    #Pool subtypes across CI and CS for a particular entity class
    subtype_sources = [cs_subtypes] if is_blank_statement else [cs_subtypes,ci_subtypes]
    subtype_factors = {}
    for e_class in sorted_classes:
        for source in subtype_sources:
            if e_class not in source:
                continue
            stypes = source[e_class]
            if e_class not in subtype_factors:
                subtype_factors[e_class] = {}
            bucket = subtype_factors[e_class]
            for st in stypes:
                if st in bucket:
                    bucket[st] += stypes[st]
                else:
                    bucket[st] = stypes[st]

    sorted_subtype_factors = {}
    for e_class in subtype_factors:
        sorted_stypes = OrderedDict(sorted(subtype_factors[e_class].items(), key=lambda kv: kv[1], reverse=True))
        stypes_dist = self.convert_positive_nums_to_dist(sorted_stypes)
        sorted_subtype_factors[e_class] = {"stypes":list(sorted_stypes.keys()),"dist":stypes_dist}

    pooled_results = OrderedDict()
    assert(len(main_classes_arr) == len(main_dist))
    d_str_arr = ["\n***CONSOLIDATED ENTITY:"]
    for e_class, confidence in zip(main_classes_arr,main_dist):
        pooled_results[e_class] = {"e":e_class,"confidence":confidence}
        d_str_arr.append(e_class + " " + str(confidence))
        stypes_dict = sorted_subtype_factors[e_class]
        pooled_st = OrderedDict()
        for st, sd in zip(stypes_dict["stypes"],stypes_dict["dist"]):
            pooled_st[st] = sd
        pooled_results[e_class]["stypes"] = pooled_st
    summary = ' '.join(d_str_arr)
    debug_str_arr.append(summary)
    print(summary)
    return pooled_results
225
+
226
+
227
+
228
+
229
+
230
+
231
+
232
+
233
+
234
def init_entity_info(self,entity_info_dict,index):
    """Install an empty CI/CS record at ``entity_info_dict[index]``.

    The record holds ``"ci"`` and ``"cs"`` sections (in that order), each
    with its own empty ``"entities"`` and ``"descs"`` lists.
    """
    entity_info_dict[index] = OrderedDict(
        (section, OrderedDict([("entities", []), ("descs", [])]))
        for section in ("ci", "cs")
    )
243
+
244
+
245
+
246
+
247
+ #This now does specific tagging if there is a __entity__ in sentence; else does full tagging. TBD.
248
+ #TBD. Make response params same regardlesss of output format. Now it is different
249
def tag_sentence(self,sent,rfp,dfp,json_output,desc_obj):
    """Tag ``sent`` using the descriptor predictions in ``desc_obj``.

    Args:
        sent: input sentence.
        rfp: results log file, passed through to ``find_cs_entities``.
        dfp: debug log file; input, debug lines and an END marker are written.
        json_output: when true, return a JSON string; otherwise a tuple.
        desc_obj: descriptor predictions supplied by the caller (used in
            place of a call to the descriptor server).

    Returns:
        A JSON string with the term tags, pooled entity distributions and
        CI/CS details when ``json_output`` is true.
    """
    print("Input: ", sent)
    dfp.write("\n\n++++-------------------------------\n")
    dfp.write("NER_INPUT: " + sent + "\n")
    debug_str_arr = []
    entity_info_dict = OrderedDict()
    main_obj = desc_obj
    #Find CI predictions for ALL masked predictios in sentence
    ci_predictions,orig_ci_entities = self.find_ci_entities(main_obj,debug_str_arr,entity_info_dict) #ci_entities is the same info as ci_predictions except packed differently for output
    #Find CS predictions for ALL masked predictios in sentence. Use the CI predictions from previous step to
    #pool
    detected_entities_arr,ner_str,full_pooled_results,orig_cs_entities = self.find_cs_entities(sent,main_obj,rfp,dfp,debug_str_arr,ci_predictions,entity_info_dict)
    assert(len(detected_entities_arr) == len(entity_info_dict))
    print("--------")
    if (json_output):
        # Only reachable when assertions are disabled (python -O).
        if (len(detected_entities_arr) != len(entity_info_dict)):
            if (len(entity_info_dict) == 0):
                # Position 1 is the entry filled in just below; the previous
                # code passed an undefined name here (NameError).
                self.init_entity_info(entity_info_dict,1)
                entity_info_dict[1]["cs"]["entities"].append([{"e":"O","confidence":1}])
                entity_info_dict[1]["ci"]["entities"].append([{"e":"O","confidence":1}])
        ret_dict,ref_indices_arr = self.dictify_ner_response(ner_str) #Convert ner string to a dictionary for json output
        assert(len(ref_indices_arr) == len(detected_entities_arr))
        assert(len(entity_info_dict) == len(detected_entities_arr))
        cs_aux_dict = OrderedDict()
        ci_aux_dict = OrderedDict()
        cs_aux_orig_entities = OrderedDict()
        ci_aux_orig_entities = OrderedDict()
        pooled_pred_dict = OrderedDict()
        assert(len(full_pooled_results) == len(detected_entities_arr))
        assert(len(full_pooled_results) == len(orig_cs_entities))
        assert(len(full_pooled_results) == len(orig_ci_entities))
        combined = zip(detected_entities_arr,entity_info_dict,full_pooled_results,orig_cs_entities,orig_ci_entities)
        for count,(e,c,p,o,i) in enumerate(combined):
            val = entity_info_dict[c]
            ref = ref_indices_arr[count]  # position of this entity's B_ term
            pooled_pred_dict[ref] = {"e": e, "cs_distribution": list(p.values())}
            cs_aux_dict[ref] = {"e":e,"cs_descs":val["cs"]["descs"]}
            ci_aux_dict[ref] = {"ci_descs":val["ci"]["descs"]}
            cs_aux_orig_entities[ref] = {"e":e,"cs_distribution": o}
            # NOTE(review): the CI details also use the "cs_distribution" key;
            # kept as-is because consumers may depend on it - confirm.
            ci_aux_orig_entities[ref] = {"e":e,"cs_distribution": i}
        final_ret_dict = {"total_terms_count":len(ret_dict),"detected_entity_phrases_count":len(detected_entities_arr),"ner":ret_dict,"entity_distribution":pooled_pred_dict,"cs_prediction_details":cs_aux_dict,"ci_prediction_details":ci_aux_dict,"orig_cs_prediction_details":cs_aux_orig_entities,"orig_ci_prediction_details":ci_aux_orig_entities,"debug":debug_str_arr}
        json_str = json.dumps(final_ret_dict,indent = 4)

        dfp.write('\n'.join(debug_str_arr))
        dfp.write("\n\nEND-------------------------------\n")
        dfp.flush()
        return json_str
    else:
        print(detected_entities_arr)
        debug_str_arr.append("NER_FINAL_RESULTS: " + ' '.join(detected_entities_arr))
        print("--------")
        dfp.write('\n'.join(debug_str_arr))
        dfp.write("\n\nEND-------------------------------\n")
        dfp.flush()
        # NOTE(review): span_arr and terms_arr are not defined anywhere in this
        # method, so this branch raises NameError as written. Their intended
        # source is not visible here - confirm before using non-JSON output.
        return detected_entities_arr,span_arr,terms_arr,ner_str,debug_str_arr
318
+
319
def masked_word_first_letter_capitalize(self, entity):
    """Capitalize whitespace-separated words whose first two characters are lowercase.

    Words shorter than two characters, and words whose first or second
    character is not lowercase (e.g. "eGFR"), are left unchanged. The result
    is re-joined with single spaces.
    """
    def _capitalize(word):
        # Only touch words that clearly start in lowercase ("york" -> "York").
        if len(word) > 1 and word[0].islower() and word[1].islower():
            return word[0].upper() + word[1:]
        return word

    return ' '.join(_capitalize(word) for word in entity.split())
328
+
329
+
330
def gen_single_phrase_sentences(self, terms_arr, masked_sent_arr, span_arr, rfp, dfp):
    """Build one "<phrase> is a entity" sentence per marked span.

    A span is a maximal run of 1s in ``span_arr``; the words of the span are
    read from ``terms_arr[i][WORD_POS]``. Returns ``(sentences, spans)`` where
    each span list holds a 1 per phrase word followed by a 0 per literal
    template word. ``masked_sent_arr``, ``rfp`` and ``dfp`` are not used here.
    """
    sentence_template = "%s is a entity"
    # One 0 for every template word other than the "%s" placeholder.
    template_zeros = [0 for word in sentence_template.split() if word != "%s"]
    print(span_arr)
    sentences = []
    singleton_spans_arr = []
    total = len(span_arr)
    pos = 0
    while pos < total:
        if span_arr[pos] != 1:
            pos += 1
            continue
        # Consume the whole run of 1s starting at pos.
        start = pos
        while pos < total and span_arr[pos] == 1:
            pos += 1
        phrase = " ".join(terms_arr[k][WORD_POS] for k in range(start, pos))
        entity = self.masked_word_first_letter_capitalize(phrase)
        sentence = sentence_template % entity
        singleton_span = [1] * (pos - start) + list(template_zeros)
        sentences.append(sentence)
        singleton_spans_arr.append(singleton_span)
        print(sentence)
        print(singleton_span)
    return sentences, singleton_spans_arr
366
+
367
+
368
def find_ci_entities(self, main_obj, debug_str_arr, entity_info_dict):
    """Compute the "ci_prediction" entity predictions for every masked position.

    Iterates ``main_obj["descs_and_entities"]`` in order, numbering positions
    from 1. For each position the CI sentence is logged, the entity-info slot
    is initialised, and the entities/confidences/subtypes are computed.

    Returns ``(ci_predictions, orig_ci_confidences)``: a list of dicts with
    keys "entities", "confidences", "subtypes", and a parallel list of the
    packed confidences (sent on for cross-prediction detection).
    """
    ci_predictions = []
    orig_ci_confidences = []
    batch_obj = main_obj["descs_and_entities"]
    for term_index, key in enumerate(batch_obj, start=1):
        ci_prediction = batch_obj[key]["ci_prediction"]
        masked_sent = ci_prediction["sentence"]
        print("\n**CI: ", masked_sent)
        debug_str_arr.append(masked_sent)
        inp_arr = ci_prediction["descs"]
        descs = self.get_descriptors_for_masked_position(inp_arr)
        self.init_entity_info(entity_info_dict, term_index)
        entities, confidences, subtypes = self.get_entities_for_masked_position(
            inp_arr, descs, debug_str_arr, entity_info_dict[term_index]["ci"])
        ci_predictions.append({
            "entities": entities,
            "confidences": confidences,
            "subtypes": subtypes,
        })
        orig_ci_confidences.append(self.pack_confidences(entities, confidences))
    return ci_predictions, orig_ci_confidences
386
+
387
+
388
def pack_confidences(self, cs_entities, cs_confidences):
    """Pair each entity label with its confidence, keeping only the top subtype.

    A label of the form "MAIN[sub1,sub2,...]" is reduced to "MAIN[sub1]", or
    to plain "MAIN" when sub1 equals MAIN or the label has no bracket part.
    Returns a list of ``{"e": label, "confidence": value}`` dicts.
    """
    assert(len(cs_entities) == len(cs_confidences))
    packed = []
    for label, conf in zip(cs_entities, cs_confidences):
        print(label, conf)
        pieces = label.split('[')
        main_type = pieces[0]
        name = main_type
        if len(pieces) > 1:
            # First comma-separated subtype inside the brackets.
            top_subtype = pieces[1].split(',')[0].rstrip(']')
            if top_subtype != main_type:
                name = main_type + '[' + top_subtype + ']'
        packed.append({"e": name, "confidence": conf})
    return packed
405
+
406
+
407
+ #We have multiple masked versions of a single sentence. Tag each one of them
408
+ #and create a complete tagged version for a sentence
409
def find_cs_entities(self, sent, main_obj, rfp, dfp, debug_str_arr, ci_predictions, entity_info_dict):
    """Tag every masked position using its "cs_prediction" pooled with the CI result.

    For each key of ``main_obj["descs_and_entities"]`` (positions numbered
    from 1) the CS entities are computed when descriptors exist, otherwise
    they are empty; they are then pooled with ``ci_predictions[index]`` and the
    top pooled prediction is recorded. Finally the tagged sentence is emitted
    to ``rfp``.

    Returns ``(detected_entities_arr, ner_str, full_pooled_results,
    orig_cs_confidences)``.
    """
    batch_obj = main_obj["descs_and_entities"]
    dfp.write(sent + "\n")
    detected_entities_arr = []
    full_pooled_results = []
    orig_cs_confidences = []
    for index, key in enumerate(batch_obj):
        term_index = index + 1
        cs_prediction = batch_obj[key]["cs_prediction"]
        position_info = cs_prediction["descs"]
        ci = ci_predictions[index]
        ci_entities = ci["entities"]
        ci_confidences = ci["confidences"]
        ci_subtypes = ci["subtypes"]
        debug_str_arr.append("\n++++++ nth Masked term : " + str(key))
        masked_sent = cs_prediction["sentence"]
        print("\n**CS: ", masked_sent)
        descs = self.get_descriptors_for_masked_position(position_info)
        if len(descs) > 0:
            cs_entities, cs_confidences, cs_subtypes = self.get_entities_for_masked_position(
                position_info, descs, debug_str_arr, entity_info_dict[term_index]["cs"])
        else:
            # No descriptors for this position: nothing context-sensitive to pool.
            cs_entities, cs_confidences, cs_subtypes = [], [], []
        pooled_results = self.pool_confidences(
            ci_entities, ci_confidences, ci_subtypes,
            cs_entities, cs_confidences, cs_subtypes,
            debug_str_arr, sent, dfp)
        # Records only the top pooled prediction for this position.
        self.fill_detected_entities(detected_entities_arr, pooled_results)
        full_pooled_results.append(pooled_results)
        # Un-pooled CS confidences are kept separately for cross-prediction detection.
        orig_cs_confidences.append(self.pack_confidences(cs_entities, cs_confidences))
    # All positions processed: emit the tagged sentence.
    ner_str = self.emit_sentence_entities(
        sent, main_obj["terms_arr"], detected_entities_arr, main_obj["span_arr"], rfp)
    dfp.flush()
    return detected_entities_arr, ner_str, full_pooled_results, orig_cs_confidences
447
+
448
+
449
def fill_detected_entities(self, detected_entities_arr, entities):
    """Append the top prediction from ``entities`` to ``detected_entities_arr``.

    ``entities`` maps entity class -> info dict with an "stypes" mapping; the
    first key of each (iteration order) is taken as the top class / subtype.
    The label is "CLASS[subtype]" unless they are equal, then just "CLASS".
    An empty ``entities`` yields "OTHER".
    """
    if len(entities) == 0:
        detected_entities_arr.append("OTHER")
        return
    top_e_class = next(iter(entities))
    top_subtype = next(iter(entities[top_e_class]["stypes"]))
    if top_e_class == top_subtype:
        top_prediction = top_e_class
    else:
        top_prediction = top_e_class + "[" + top_subtype + "]"
    detected_entities_arr.append(top_prediction)
460
+
461
+
462
def fill_detected_entities_old(self, detected_entities_arr, entities, pan_arr):
    """Append the best-scoring term from rank-weighted ``entities`` (legacy).

    Each entry of ``entities`` is split on "-" and then "/"; every resulting
    term accumulates a weight of 1/rank. The highest-weighted term is
    appended, or "OTHER" when there are no terms. ``pan_arr`` is unused.

    NOTE(review): the rank counter is assumed to advance once per entry of
    ``entities`` (outer loop) -- confirm against the original file layout.
    """
    weights = {}
    for rank, entry in enumerate(entities, start=1):
        for candidate in entry.split("-"):
            for term in candidate.split("/"):
                weights[term] = weights.get(term, 0) + 1.0 / rank
    ranked = sorted(weights.items(), key=lambda kv: kv[1], reverse=True)
    best = ranked[0][0] if ranked else "OTHER"
    detected_entities_arr.append(best)
480
+
481
+ #Contextual entity is picked as first candidate before context independent candidate
482
def old_resolve_entities(self, index, singleton_entities, detected_entities_arr):
    """Merge the CI label into the CS label at ``index`` (legacy), in place.

    ``detected_entities_arr[index]`` is the context-sensitive (CS) label and
    ``singleton_entities[index]`` the context-independent (CI) one, both of
    the form "MAIN[sub1,sub2,...]".

    * Different main types, neither "OTHER": result is "CS/CI".
    * CS is "OTHER": result is the CI label.
    * CI is "OTHER": CS label is kept.
    * Same main type: subtypes are interleaved CS, CI, CS, CI, ... keeping
      first occurrences only.

    Fixes over the previous version: the interleave now runs over the longer
    of the two subtype lists (extra CI subtypes used to be dropped), and
    labels without a "[...]" part no longer raise IndexError.
    """
    def _subtypes(label):
        # Comma-separated names between the first '[' and the closing ']'.
        pieces = label.split('[')
        if len(pieces) < 2:
            return []
        return pieces[1].rstrip(']').split(',')

    cs_label = detected_entities_arr[index]
    ci_label = singleton_entities[index]
    cs_main = cs_label.split('[')[0]
    ci_main = ci_label.split('[')[0]

    if ci_main != cs_main:
        if ci_main != "OTHER" and cs_main != "OTHER":
            detected_entities_arr[index] = cs_label + "/" + ci_label
        elif cs_main == "OTHER":
            detected_entities_arr[index] = ci_label
        # else: CI is OTHER, keep the CS label untouched.
        return

    cs_arr = _subtypes(cs_label)
    ci_arr = _subtypes(ci_label)
    merged = OrderedDict()
    for i in range(max(len(cs_arr), len(ci_arr))):
        if i < len(cs_arr):
            merged.setdefault(cs_arr[i], 1)
        if i < len(ci_arr):
            merged.setdefault(ci_arr[i], 1)
    if merged:
        detected_entities_arr[index] = cs_main + '[' + ','.join(merged) + ']'
    else:
        detected_entities_arr[index] = cs_main
506
+
507
+
508
+
509
+
510
+
511
+
512
def emit_sentence_entities(self, sent, terms_arr, detected_entities_arr, span_arr, rfp):
    """Write the sentence as "term TAG" lines to ``rfp`` and return the same text.

    ``span_arr[i]`` is 0 for an untagged term (tag "O") and non-zero for a
    term inside an entity span. The first term of a span is tagged
    "B_<entity>", following ones "I_<entity>"; each completed span advances to
    the next entry of ``detected_entities_arr``. A blank line terminates the
    sentence. ``sent`` is accepted for interface compatibility but not used.

    Raises AssertionError when ``terms_arr`` and ``span_arr`` differ in length.
    """
    print("Final result")
    for term in terms_arr:
        print(term, ' ', end='')
    print()
    assert(len(terms_arr) == len(span_arr))
    lines = []
    entity_index = 0
    in_span = False
    for term, flag in zip(terms_arr, span_arr):
        if flag == 0:
            tag = "O"
            if in_span:
                # Leaving a span: the next span uses the next detected entity.
                in_span = False
                entity_index += 1
        elif in_span:
            tag = "I_" + detected_entities_arr[entity_index]
        else:
            in_span = True
            tag = "B_" + detected_entities_arr[entity_index]
        line = term + ' ' + tag + "\n"
        rfp.write(line)
        lines.append(line)
        print(tag + ' ', end='')
    print()
    rfp.write("\n")
    lines.append("\n")
    rfp.flush()
    return ''.join(lines)
544
+
545
+
546
+
547
+
548
+
549
def get_descriptors_for_masked_position(self, inp_arr):
    """Flatten descriptor records into ``[desc1, v1, desc2, v2, ...]``.

    Each element of ``inp_arr`` must provide the keys "desc" and "v".
    """
    desc_arr = []
    for record in inp_arr:
        desc_arr.extend((record["desc"], record["v"]))
    return desc_arr
555
+
556
def dispatch_request(self, url):
    """GET ``url``, retrying up to 10 times; return the response or None.

    A response is returned only when its status code is 200. Both transport
    errors and non-200 responses count as a failed attempt, so the loop is
    guaranteed to terminate. Only ``requests`` errors are caught, so
    KeyboardInterrupt / SystemExit are no longer swallowed.

    Returns None after the final failed attempt.
    """
    max_retries = 10
    for _ in range(max_retries):
        try:
            r = requests.get(url, timeout=1000)
        except requests.exceptions.RequestException:
            print("Request:", url, " failed. Retrying...")
            continue
        if r.status_code == 200:
            return r
        print("Request:", url, " returned status", r.status_code, ". Retrying...")
    print("Request:", url, " failed")
    return None
570
+
571
def convert_positive_nums_to_dist(self, final_sorted_d):
    """Normalize the values of ``final_sorted_d`` so they sum to 1.

    Values are converted with ``float``. When they sum to zero, all weight is
    put on the first entry so the result still sums to 1. An empty mapping
    yields an empty array (previously an IndexError).

    Returns a numpy array rounded to 4 decimals, in the mapping's value order.
    """
    factors = [float(value) for value in final_sorted_d.values()]
    if not factors:
        return np.array([])
    total = float(sum(factors))
    if total == 0:
        # Boundary case (e.g. all-zero scores): make the first entry 100%.
        total = 1
        factors[0] = 1
    return np.round(np.array(factors) / total, 4)
583
+
584
def get_desc_weights_total(self, count, desc_weights):
    """Sum the weights stored at the odd indices of ``desc_weights``.

    ``desc_weights`` is laid out as ``[desc1, weight1, desc2, weight2, ...]``
    and the first ``count`` slots are considered. Returns 1 instead of a zero
    total.
    """
    total = 0
    for pos in range(0, count, 2):
        total += float(desc_weights[pos + 1])
    return 1 if total == 0 else total
592
+
593
+
594
def aggregate_entities(self, entities, desc_weights, debug_str_arr, entity_info_dict_entities):
    """Aggregate per-descriptor entity predictions into one distribution.

    ``desc_weights`` is the descriptor array ``[desc1, score1, desc2, score2,
    ...]`` for a masked position. ``entities`` is the parallel array
    ``[e1, counts1, e2, counts2, ...]`` where ``e1`` looks like
    "DRUG/DISEASE" and ``counts1`` like "10/8".

    For every descriptor the entity names are consolidated through
    ``map_entities`` (which also accumulates subtype counts across all
    descriptors), the counts are passed through ``softmax``, and each
    consolidated entity accumulates ``softmax_value * descriptor_score``.
    The totals are sorted descending and normalized into a distribution.

    Side effects: appends the entity/confidence summary to ``debug_str_arr``
    and a list of ``{"e", "confidence"}`` dicts to
    ``entity_info_dict_entities``.

    Returns ``(ret_entities, confidences, subtypes)`` where ``ret_entities``
    carry their sorted subtypes as "ENTITY[sub1,sub2]".
    """
    count = len(entities)
    assert(count %2 == 0)
    totals = {}
    subtypes = {}
    for i in range(0, count, 2):
        # entities[i] holds names ("PROTEIN/GENE/PERSON"), entities[i+1] counts ("10/4/7").
        curr_counts = entities[i+1].split('/')
        # Consolidates names for this descriptor; subtype counts accumulate across descriptors.
        trunc_e, trunc_counts = self.map_entities(entities[i].split('/'), curr_counts, subtypes)
        assert(len(trunc_e) <= len(curr_counts))  # can be less if untagged is skipped
        assert(len(trunc_e) == len(trunc_counts))
        # Normalizing reduces the effect of absolute label counts when summing across descriptors.
        trunc_counts = softmax(trunc_counts)
        for j, name in enumerate(trunc_e):
            if self.skip_untagged(name):
                continue
            contribution = float(trunc_counts[j]) * float(desc_weights[i+1])
            totals[name] = totals.get(name, 0.0) + contribution
    final_sorted_d = OrderedDict(sorted(totals.items(), key=lambda kv: kv[1], reverse=True))
    if len(final_sorted_d) == 0:  # Case where all terms are tagged OTHER
        final_sorted_d = {"OTHER": 1}
        subtypes["OTHER"] = {"OTHER": 1}
    factors = self.convert_positive_nums_to_dist(final_sorted_d)
    ret_entities = list(final_sorted_d.keys())
    confidences = factors.tolist()
    print(ret_entities)
    sorted_subtypes = self.sort_subtypes(subtypes)
    ret_entities = self.update_entities_with_subtypes(ret_entities, sorted_subtypes)
    print(ret_entities)
    debug_str_arr.append(" ")
    debug_str_arr.append(' '.join(ret_entities))
    print(confidences)
    assert(len(confidences) == len(ret_entities))
    entity_info_dict_entities.append(
        [{"e": e, "confidence": c} for e, c in zip(ret_entities, confidences)])
    debug_str_arr.append(' '.join([str(x) for x in confidences]))
    debug_str_arr.append("\n\n")
    return ret_entities, confidences, subtypes
650
+
651
+
652
def sort_subtypes(self, subtypes):
    """Return, per entity, its subtype names ordered by descending count.

    ``subtypes`` maps entity -> {subtype: count}. The result is an
    OrderedDict mapping entity -> list of subtype names; ties keep their
    original relative order.
    """
    sorted_subtypes = OrderedDict()
    for ent, counts in subtypes.items():
        ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
        sorted_subtypes[ent] = [name for name, _ in ranked]
    return sorted_subtypes
658
+
659
def update_entities_with_subtypes(self, ret_entities, subtypes):
    """Decorate each entity with its subtypes as "ENTITY[sub1,sub2,...]".

    Entities that have no entry in ``subtypes`` are returned unchanged.
    """
    def _decorate(ent):
        if ent not in subtypes:
            return ent
        return ent + '[' + ','.join(subtypes[ent]) + ']'

    return [_decorate(ent) for ent in ret_entities]
671
+
672
def skip_untagged(self, term):
    """Return True when untagged suppression is on and ``term`` is an untagged label.

    The untagged labels are "OTHER" and "UNTAGGED_ENTITY"; suppression is on
    only when ``self.suppress_untagged == True``.
    """
    is_untagged_label = (term == "OTHER" or term == "UNTAGGED_ENTITY")
    if self.suppress_untagged == True and is_untagged_label:
        return True
    return False
676
+
677
+
678
def map_entities(self, arr, counts_arr, subtypes_dict):
    """Map raw entity names to consolidated names and tally subtype counts.

    For every term of ``arr`` that is not skipped by ``skip_untagged``, the
    consolidated name ``self.entity_map[term]`` and the integer count from
    ``counts_arr`` are collected, and ``subtypes_dict[consolidated][term]`` is
    increased by that count (``subtypes_dict`` is updated in place).

    Returns ``(consolidated_names, counts)`` as parallel lists.
    """
    ret_arr = []
    new_counts_arr = []
    for index, term in enumerate(arr):
        if self.skip_untagged(term):
            continue
        consolidated = self.entity_map[term]
        amount = int(counts_arr[index])
        ret_arr.append(consolidated)
        new_counts_arr.append(amount)
        per_entity = subtypes_dict.setdefault(consolidated, {})
        per_entity[term] = per_entity.get(term, 0) + amount
    return ret_arr, new_counts_arr
695
+
696
def get_entities_from_batch(self, inp_arr):
    """Flatten entity records into ``[e1, e_count1, e2, e_count2, ...]``.

    Each element of ``inp_arr`` must provide the keys "e" and "e_count".
    """
    entities_arr = []
    for record in inp_arr:
        entities_arr.extend((record["e"], record["e_count"]))
    return entities_arr
702
+
703
+
704
def get_entities_for_masked_position(self, inp_arr, descs, debug_str_arr, entity_info_dict):
    """Aggregate the entity predictions of one masked position.

    ``descs`` is the flattened descriptor array and the entity array is
    flattened from ``inp_arr``; both alternate value / score. Each pair of
    consecutive slots is recorded in ``entity_info_dict["descs"]`` as a dict
    with keys "d", "e", "mlm", "l_score". A debug line showing at most the
    first five "/"-separated entity parts per slot is appended to
    ``debug_str_arr``.

    Returns the ``(entities, confidences, subtypes)`` from
    ``aggregate_entities``.
    """
    entities = self.get_entities_from_batch(inp_arr)
    assert(len(descs) %2 == 0)
    assert(len(entities) %2 == 0)
    debug_combined_arr = []
    desc_arr = []
    for index, (d, e) in enumerate(zip(descs, entities)):
        shortened = '/'.join(e.split('/')[:5])
        debug_combined_arr.append(d + " " + shortened)
        if index % 2 == 0:
            # Even slot: descriptor text and its entity names.
            temp_dict = OrderedDict()
            temp_dict["d"] = d
            temp_dict["e"] = e
        else:
            # Odd slot: the paired scores complete the record.
            temp_dict["mlm"] = d
            temp_dict["l_score"] = e
            desc_arr.append(temp_dict)
    debug_str_arr.append("\n" + ', '.join(debug_combined_arr))
    print(debug_combined_arr)
    entity_info_dict["descs"] = desc_arr
    assert(len(entities) == len(descs))
    return self.aggregate_entities(entities, descs, debug_str_arr, entity_info_dict["entities"])
730
+
731
+
732
+ #This is again a bad hack for prototyping purposes - extracting fields from a raw text output as opposed to a structured output like json
733
def extract_descs(self, text):
    """Return the whitespace-separated tokens of the first DESC_HEAD line.

    The line is split on ":"; everything after the first field is re-joined
    with spaces and tokenized. Returns an empty list when no line of ``text``
    starts with ``DESC_HEAD``.
    """
    for line in text.split('\n'):
        if line.startswith(DESC_HEAD):
            fields = line.split(':')
            return ' '.join(fields[1:]).strip().split()
    return []
743
+
744
+
745
def generate_masked_sentences(self, terms_arr):
    """Create one masked sentence per run of noun-tagged terms.

    Returns ``(sentence_arr, span_arr)``: the masked sentences produced by
    ``gen_sentence`` and a per-term flag list (1 for terms inside a masked
    run, 0 otherwise).
    """
    size = len(terms_arr)
    sentence_arr = []
    span_arr = []
    pos = 0
    while pos < size:
        if terms_arr[pos][TAG_POS] in noun_tags:
            # gen_sentence masks the whole run and reports its length.
            skip = self.gen_sentence(sentence_arr, terms_arr, pos)
            pos += skip
            span_arr.extend([1] * skip)
        else:
            pos += 1
            span_arr.append(0)
    return sentence_arr, span_arr
762
+
763
def gen_sentence(self, sentence_arr, terms_arr, index):
    """Append a copy of the sentence with the noun run at ``index`` masked.

    The consecutive noun-tagged terms starting at ``index`` are replaced by a
    single ``MASK_TAG``; words before and after are copied as-is. Returns the
    number of terms that were masked (asserted to be non-zero).
    """
    size = len(terms_arr)
    # Words before the run are kept verbatim.
    new_sent = [term[WORD_POS] for term in terms_arr[:index]]
    run_end = index
    while run_end < size and terms_arr[run_end][TAG_POS] in noun_tags:
        run_end += 1
    skip = run_end - index
    new_sent.append(MASK_TAG)
    for k in range(index + skip, size):
        new_sent.append(terms_arr[k][WORD_POS])
    assert(skip != 0)
    sentence_arr.append(new_sent)
    return skip
784
+
785
+
786
+
787
+
788
+
789
+
790
+
791
+
792
def run_test(file_name, obj):
    """Tag every non-blank line of ``file_name`` with ``obj.tag_sentence``.

    Results go to "results.txt" and debug output to "debug.txt" in the
    current directory. The files are managed by ``with`` so they are closed
    even when tagging raises.
    """
    with open("results.txt", "w") as rfp, open("debug.txt", "w") as dfp, \
            open(file_name) as fp:
        # NOTE(review): the counter is assumed to advance only for processed
        # (non-blank) lines -- confirm against the original file layout.
        count = 1
        for line in fp:
            if len(line) > 1:
                print(str(count) + "] ", line, end='')
                obj.tag_sentence(line, rfp, dfp)
                count += 1
804
+
805
+
806
def tag_single_entity_in_sentence(file_name, obj):
    """Tag the explicitly marked entities in every non-blank line of ``file_name``.

    Calls ``obj.tag_sentence(line, rfp, dfp, True)`` (JSON output mode) with
    "results.txt" and "debug.txt" as output files. "se_results.txt" is still
    created for compatibility, although nothing is currently written to it.
    All files are managed by ``with`` so they are closed even when tagging
    raises.
    """
    with open("results.txt", "w") as rfp, open("debug.txt", "w") as dfp, \
            open("se_results.txt", "w") as sfp, open(file_name) as fp:
        # NOTE(review): the counter/flush are assumed to run only for
        # processed (non-blank) lines -- confirm against the original layout.
        count = 1
        for line in fp:
            if len(line) > 1:
                print(str(count) + "] ", line, end='')
                obj.tag_sentence(line, rfp, dfp, True)  # True for json output
                count += 1
                sfp.flush()
825
+
826
+
827
+
828
+
829
# Canned sentences for the "canned" run option. Words suffixed with
# ":__entity__" mark the terms to be tagged explicitly; sentences without
# the suffix carry no explicit marks.
test_arr = [
    "He felt New:__entity__ York:__entity__ has a chance to win this year's competition",
    "Ajit rajasekharan is an engineer at nFerence:__entity__",
    "Ajit:__entity__ rajasekharan is an engineer:__entity__ at nFerence:__entity__",
    "Mesothelioma:__entity__ is caused by exposure to asbestos:__entity__",
    "Fyodor:__entity__ Mikhailovich:__entity__ Dostoevsky:__entity__ was treated for Parkinsons",
    "Ajit:__entity__ Rajasekharan:__entity__ is an engineer at nFerence",
    "A eGFR:__entity__ below 60 indicates chronic kidney disease",
    "A eGFR below 60:__entity__ indicates chronic kidney disease",
    "A eGFR:__entity__ below 60:__entity__ indicates chronic:__entity__ kidney:__entity__ disease:__entity__",
    "Ajit:__entity__ rajasekharan is an engineer at nFerence",
    "Her hypophysitis secondary to ipilimumab was well managed with supplemental hormones",
    "In Seattle:__entity__ , Pete Incaviglia 's grand slam with one out in the sixth snapped a tie and lifted the Baltimore Orioles past the Seattle Mariners , 5-2 .",
    "engineer",
    "Austin:__entity__ called",
    "Paul Erdős died at 83",
    "Imatinib mesylate is a drug and is used to treat nsclc",
    "In Seattle , Pete Incaviglia 's grand slam with one out in the sixth snapped a tie and lifted the Baltimore Orioles past the Seattle Mariners , 5-2 .",
    "It was Incaviglia 's sixth grand slam and 200th homer of his career .",
    "Add Women 's singles , third round Lisa Raymond ( U.S. ) beat Kimberly Po ( U.S. ) 6-3 6-2 .",
    "1880s marked the beginning of Jazz",
    "He flew from New York to SFO",
    "Lionel Ritchie was popular in the 1980s",
    "Lionel Ritchie was popular in the late eighties",
    "John Doe flew from New York to Rio De Janiro via Miami",
    "He felt New York has a chance to win this year's competition",
    "Bandolier - Budgie ' , a free itunes app for ipad , iphone and ipod touch , released in December 2011 , tells the story of the making of Bandolier in the band 's own words - including an extensive audio interview with Burke Shelley",
    "In humans mutations in Foxp2 leads to verbal dyspraxia",
    "The recent spread of Corona virus flu from China to Italy,Iran, South Korea and Japan has caused global concern",
    "Hotel California topped the singles chart",
    "Elon Musk said Telsa will open a manufacturing plant in Europe",
    "He flew from New York to SFO",
    "After studies at Hofstra University , He worked for New York Telephone before He was elected to the New York State Assembly to represent the 16th District in Northwest Nassau County ",
    "Everyday he rode his bicycle from Rajakilpakkam to Tambaram",
    "If he loses Saturday , it could devalue his position as one of the world 's great boxers , \" Panamanian Boxing Association President Ramon Manzanares said .",
    "West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .",
    "they are his friends ",
    "they flew from Boston to Rio De Janiro and had a mocha",
    "he flew from Boston to Rio De Janiro and had a mocha",
    "X,Y,Z are medicines"]
869
+
870
+
871
def test_canned_sentences(obj):
    """Tag every sentence of ``test_arr`` in JSON mode, pausing in pdb after each.

    Results go to "results.txt" and debug output to "debug.txt"; both files
    are managed by ``with`` so they are closed even when tagging raises or the
    debugger session is aborted.

    NOTE(review): the pdb breakpoints make this entry point interactive-only;
    they are kept because ``ret_val`` is otherwise only inspectable there.
    """
    with open("results.txt", "w") as rfp, open("debug.txt", "w") as dfp:
        pdb.set_trace()
        for line in test_arr:
            ret_val = obj.tag_sentence(line, rfp, dfp, True)
            pdb.set_trace()
880
+
881
# Command-line entry point. The run mode is chosen with -option:
#   canned   - tag the built-in test_arr sentences
#   batch    - tag every sentence of the -input file
#   specific - tag only the ":__entity__"-marked terms of the -input file
# The help text previously advertised "single", which the dispatch below never
# accepted; it now names "specific" consistently.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='main NER for a single model ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-input', action="store", dest="input",default="",help='Input file required for run options batch,specific')
    parser.add_argument('-config', action="store", dest="config", default=DEFAULT_CONFIG,help='config file path')
    parser.add_argument('-option', action="store", dest="option",default="canned",help='Valid options are canned,batch,specific. canned - test few canned sentences used in medium article. batch - tag sentences in input file. Entities to be tagged are determined using POS tagging to find noun phrases. specific - tag specific entities in input file. The tagged word or phrases needs to be of the form w1:__entity__ w2:__entity__ Example:Her hypophysitis:__entity__ secondary to ipilimumab was well managed with supplemental:__entity__ hormones:__entity__')
    results = parser.parse_args()

    obj = UnsupNER(results.config)
    if (results.option == "canned"):
        test_canned_sentences(obj)
    elif (results.option == "batch"):
        if (len(results.input) == 0):
            print("Input file needs to be specified")
        else:
            run_test(results.input,obj)
            print("Tags and sentences are written in results.txt and debug.txt")
    elif (results.option == "specific"):
        if (len(results.input) == 0):
            print("Input file needs to be specified")
        else:
            tag_single_entity_in_sentence(results.input,obj)
            print("Tags and sentences are written in results.txt and debug.txt")
    else:
        print("Invalid argument:\n")
        parser.print_help()
bbc/bbc_labels.txt ADDED
The diff for this file is too large to render. See raw diff
 
bbc/desc_bbc_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"POS_SERVER_URL": "http://127.0.0.1:8073/",
2
+ "LOG_DESCS": "0",
3
+ "USE_CLS": "0",
4
+ "BASE_PATH":"./bbc/",
5
+ "COMMON_DESCS_FILE": "untagged_terms.txt"
6
+ }
bbc/ner_bbc_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {"POS_SERVER_URL": "http://127.0.0.1:8073/",
2
+ "DESC_SERVER_URL": "http://127.0.0.1:8088/dummy/0/",
3
+ "ENTITY_SERVER_URL": "http://127.0.0.1:8043/",
4
+ "EMAP_FILE": "entity_types_consolidated.txt",
5
+ "FULL_SENTENCE_TAG": "1",
6
+ "SUPPRESS_UNTAGGED": "1",
7
+ "BASE_PATH":"./bbc/",
8
+ "COMMON_DESCS_FILE": "untagged_terms.txt"}
bbc/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
bio/a100_labels.txt ADDED
The diff for this file is too large to render. See raw diff
 
bio/desc_a100_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"POS_SERVER_URL": "http://127.0.0.1:8073/",
2
+ "LOG_DESCS": "0",
3
+ "USE_CLS": "1",
4
+ "BASE_PATH":"./bio/",
5
+ "COMMON_DESCS_FILE": "untagged_terms.txt"
6
+ }
bio/ner_a100_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {"POS_SERVER_URL": "http://127.0.0.1:8073/",
2
+ "DESC_SERVER_URL": "http://127.0.0.1:8087/dummy/0/",
3
+ "ENTITY_SERVER_URL": "http://127.0.0.1:8043/",
4
+ "EMAP_FILE": "entity_types_consolidated.txt",
5
+ "FULL_SENTENCE_TAG": "1",
6
+ "SUPPRESS_UNTAGGED": "1",
7
+ "BASE_PATH":"./bio/",
8
+ "COMMON_DESCS_FILE": "untagged_terms.txt"}
bio/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
common.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import sys
3
+
4
# Indices into a per-term record (terms are accessed as term[WORD_POS] / term[TAG_POS]).
WORD_POS = 1
TAG_POS = 2

# Placeholder written in place of a masked phrase, and the suffix that marks
# a word for tagging in user input.
MASK_TAG = "__entity__"
INPUT_MASK_TAG = ":__entity__"

# Tag assigned to words that are not marked for tagging.
RESET_POS_TAG = 'RESET'

# POS tags treated as part of a maskable phrase; noun_tags[0] is also used as
# the tag assigned to explicitly marked words.
noun_tags = [
    'NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS',
    'JJS', 'JJR', 'NNP', 'POS', 'CD',
]
# POS tags whose words get their first letter capitalized.
cap_tags = [
    'NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS',
    'JJS', 'JJR', 'NNP', 'PRP',
]
13
+
14
+
15
def detect_masked_positions(terms_arr):
    """Return the plain word list plus the masked sentences and span flags.

    Returns ``(words, masked_sentences, span_arr)`` where ``words`` is
    ``term[WORD_POS]`` for every term and the other two come from
    ``generate_masked_sentences``.
    """
    sentence_arr, span_arr = generate_masked_sentences(terms_arr)
    new_sent_arr = [term[WORD_POS] for term in terms_arr]
    return new_sent_arr, sentence_arr, span_arr
21
+
22
def generate_masked_sentences(terms_arr):
    """Create one masked sentence per run of noun-tagged terms.

    When no term carries a noun tag, ``hack_for_no_nouns_case`` first retags
    the first term so at least one sentence is produced. Returns
    ``(sentence_arr, span_arr)`` with a per-term flag list (1 inside a masked
    run, 0 otherwise).
    """
    size = len(terms_arr)
    sentence_arr = []
    span_arr = []
    hack_for_no_nouns_case(terms_arr)
    pos = 0
    while pos < size:
        if terms_arr[pos][TAG_POS] in noun_tags:
            # gen_sentence masks the whole run and reports its length.
            skip = gen_sentence(sentence_arr, terms_arr, pos)
            pos += skip
            span_arr.extend([1] * skip)
        else:
            pos += 1
            span_arr.append(0)
    return sentence_arr, span_arr
40
+
41
def hack_for_no_nouns_case(terms_arr):
    '''
    Workaround for input that has no explicitly marked entity and no nouns
    (odd inputs such as the single word "eg"): when no term carries a noun
    tag, the first term is retagged as a noun, in place, so processing can
    proceed.
    '''
    has_noun = any(term_info[TAG_POS] in noun_tags for term_info in terms_arr)
    if not has_noun and len(terms_arr) >= 1:
        terms_arr[0][TAG_POS] = noun_tags[0]
60
+
61
+
62
def gen_sentence(sentence_arr,terms_arr,index):
    """Append a copy of the sentence with the noun run at ``index`` masked.

    The consecutive noun-tagged terms starting at ``index`` are replaced by a
    single ``MASK_TAG``; words before and after are copied as-is. Returns the
    number of terms that were masked (asserted to be non-zero).
    """
    size = len(terms_arr)
    # Words before the run are kept verbatim.
    new_sent = [term[WORD_POS] for term in terms_arr[:index]]
    run_end = index
    while run_end < size and terms_arr[run_end][TAG_POS] in noun_tags:
        run_end += 1
    skip = run_end - index
    new_sent.append(MASK_TAG)
    for k in range(index + skip, size):
        new_sent.append(terms_arr[k][WORD_POS])
    assert(skip != 0)
    sentence_arr.append(new_sent)
    return skip
83
+
84
+
85
+
86
def capitalize(terms_arr):
    """Upper-case, in place, the first letter of every word whose tag is in ``cap_tags``."""
    for term_tag in terms_arr:
        if term_tag[TAG_POS] not in cap_tags:
            continue
        original = term_tag[WORD_POS]
        term_tag[WORD_POS] = original[0].upper() + original[1:]
93
+
94
def set_POS_based_on_entities(sent):
    """Build term records from a sentence with ":__entity__"-marked words.

    Every whitespace-separated word becomes a 5-slot record filled with '-'.
    Words ending in ``INPUT_MASK_TAG`` get the tag ``noun_tags[0]`` and have
    the marker removed; all other words get ``RESET_POS_TAG``.
    """
    terms_arr = []
    for word in sent.split():
        term_tag = ['-'] * 5
        marked = word.endswith(INPUT_MASK_TAG)
        term_tag[TAG_POS] = noun_tags[0] if marked else RESET_POS_TAG
        term_tag[WORD_POS] = word.replace(INPUT_MASK_TAG, "") if marked else word
        terms_arr.append(term_tag)
    return terms_arr
108
+ #print(terms_arr)
109
+
110
def filter_common_noun_spans(span_arr,masked_sent_arr,terms_arr,common_descs):
    """Drop spans whose words are all common descriptors.

    A span is a maximal run of 1s in ``span_arr``; the k-th span corresponds
    to ``masked_sent_arr[k]``. A span is filtered when the lowercase form of
    every one of its words is in ``common_descs``: its flags are cleared in
    the returned copy of ``span_arr`` and its masked sentence is omitted.

    Returns ``(kept_masked_sentences, filtered_span_arr)``; the inputs are not
    modified.
    """
    ret_span_arr = span_arr.copy()
    ret_masked_sent_arr = []
    total = len(span_arr)
    sent_index = 0
    pos = 0
    while pos < total:
        if span_arr[pos] != 1:
            pos += 1
            continue
        # Scan the whole span, noting whether any word is not common.
        start = pos
        is_all_common = True
        while pos < total and span_arr[pos] == 1:
            if terms_arr[pos][WORD_POS].lower() not in common_descs:
                is_all_common = False
            pos += 1
        if is_all_common:
            print("Filtering common span: ",end='')
            for k in range(start, pos):
                print(terms_arr[k][WORD_POS],' ',end='')
                ret_span_arr[k] = 0
            print()
        else:
            ret_masked_sent_arr.append(masked_sent_arr[sent_index])
        # Either way this span's masked sentence has been consumed.
        sent_index += 1
    return ret_masked_sent_arr,ret_span_arr
142
+
143
def normalize_casing(sent):
    """Lower-case everything but the first character of each word.

    Words are split on whitespace and re-joined with single spaces; the first
    character of every word is kept as-is.
    """
    return ' '.join(word[0] + word[1:].lower() for word in sent.split())
153
+
common_descs.txt ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a
2
+ all
3
+ an
4
+ and
5
+ any
6
+ are
7
+ as
8
+ at
9
+ away
10
+ be
11
+ beside
12
+ but
13
+ by
14
+ can
15
+ come
16
+ did
17
+ do
18
+ each
19
+ etc
20
+ far
21
+ free
22
+ get
23
+ gets
24
+ getting
25
+ give
26
+ given
27
+ gives
28
+ giving
29
+ go
30
+ goes
31
+ going
32
+ gonna
33
+ good
34
+ got
35
+ gotta
36
+ greatly
37
+ grow
38
+ growing
39
+ guess
40
+ had
41
+ has
42
+ how
43
+ in
44
+ is
45
+ it
46
+ its
47
+ itself
48
+ keep
49
+ keeps
50
+ kept
51
+ key
52
+ lack
53
+ led
54
+ let
55
+ lets
56
+ like
57
+ liked
58
+ likely
59
+ long
60
+ look
61
+ looking
62
+ looks
63
+ lose
64
+ loss
65
+ lost
66
+ lot
67
+ lots
68
+ lou
69
+ loud
70
+ made
71
+ make
72
+ matter
73
+ mean
74
+ meaning
75
+ means
76
+ meant
77
+ meet
78
+ meeting
79
+ meets
80
+ mere
81
+ merely
82
+ more
83
+ most
84
+ mostly
85
+ move
86
+ much
87
+ must
88
+ need
89
+ needed
90
+ needing
91
+ needs
92
+ new
93
+ next
94
+ nice
95
+ nobody
96
+ of
97
+ off
98
+ on
99
+ once
100
+ ongoing
101
+ only
102
+ or
103
+ place
104
+ placed
105
+ reach
106
+ same
107
+ saying
108
+ show
109
+ side
110
+ some
111
+ the
112
+ then
113
+ this
114
+ thence
115
+ thing
116
+ though
117
+ until
118
+ unto
119
+ usual
120
+ usually
121
+ wanna
122
+ want
123
+ wanted
124
+ wanting
125
+ wants
126
+ was
127
+ when
128
+ where
129
+ whereas
130
+ whereby
131
+ wherein
132
+ whether
133
+ which
134
+ while
135
+ whilst
136
+ whoever
137
+ whom
138
+ why
139
+ with
140
+ within
141
+ without
142
+ would
143
+ both
144
+ high
145
+ called
146
+ from
147
+ entitled
148
+ using
149
+ to
config_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+
4
+
5
def write_config(configs,file_name='server_config.json'):
    """Echo ``configs`` as JSON to stdout and save it to ``file_name``.

    Args:
        configs: JSON-serialisable object to persist.
        file_name: destination path; the file is overwritten.
    """
    serialized = json.dumps(configs)
    print(serialized)
    with open(file_name, 'w') as config_file:
        config_file.write(serialized)
9
+
10
+
11
def read_config(file_name='server_config.json'):
    """Load a JSON config file, falling back to an empty dict on failure.

    Args:
        file_name: path of the JSON file to read.

    Returns:
        The decoded JSON content, or ``{}`` if the file cannot be opened
        or parsed (a message is printed in that case).
    """
    try:
        with open(file_name) as json_file:
            return json.load(json_file)
    except Exception:
        # Best-effort read: any ordinary error (missing file, bad JSON, ...)
        # yields an empty config. Unlike a bare ``except:``, this lets
        # KeyboardInterrupt / SystemExit propagate instead of swallowing them.
        print("Unable to open config file:",file_name)
        return {}
entity_types_consolidated.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ THERAPEUTIC_OR_PREVENTIVE_PROCEDURE DRUG/CHEMICAL_SUBSTANCE/HAZARDOUS_OR_POISONOUS_SUBSTANCE/ESTABLISHED_PHARMACOLOGIC_CLASS/CHEMICAL_CLASS/VITAMIN/LAB_PROCEDURE/SURGICAL_AND_MEDICAL_PROCEDURES/DIAGNOSTIC_PROCEDURE/LAB_TEST_COMPONENT/STUDY/DRUG_ADJECTIVE
2
+ DISEASE MENTAL_OR_BEHAVIORAL_DYSFUNCTION/CONGENITAL_ABNORMALITY/CELL_OR_MOLECULAR_DYSFUNCTION/DISEASE_ADJECTIVE
3
+ GENE PROTEIN/ENZYME/VIRAL_PROTEIN/RECEPTOR/PROTEIN_FAMILY/MOUSE_PROTEIN_FAMILY/MOUSE_GENE/NUCLEOTIDE_SEQUENCE/GENE_EXPRESSION_ADJECTIVE
4
+ BODY_PART_OR_ORGAN_COMPONENT BODY_LOCATION_OR_REGION/BODY_SUBSTANCE/CELL/CELL_LINE/CELL_COMPONENT/BIO_MOLECULE/METABOLITE/HORMONE/BODY_ADJECTIVE
5
+ ORGANISM_FUNCTION ORGAN_OR_TISSUE_FUNCTION/PHYSIOLOGIC_FUNCTION/CELL_FUNCTION/FUNCTION_ADJECTIVE
6
+ BIO SPECIES/BACTERIUM/VIRUS/BIO_ADJECTIVE
7
+ OBJECT PRODUCT/MEDICAL_DEVICE/DEVICE/DEVICE_ADJECTIVE
8
+ MEASURE NUMBER/TIME/SEQUENCE/MEASURE_ADJECTIVE
9
+ PERSON PERSON_ADJECTIVE
10
+ ORGANIZATION UNIV/GOV/EDU/ORGANIZATION_ADJECTIVE
11
+ ENT SPORT/MOV/MUSIC/ENT_ADJECTIVE
12
+ LOCATION LOCATION_ADJECTIVE
13
+ SOCIAL_CIRCUMSTANCES RELIGION/SOCIAL_CIRCUMSTANCES_ADJECTIVE
14
+ COLOR COLOR_ADJECTIVE
15
+ LANGUAGE LANGUAGE_ADJECTIVE
16
+ GRAMMAR_CONSTRUCT
17
+ OTHER
18
+ UNTAGGED_ENTITY
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ flair
2
+ st-annotated-text
3
+
untagged_terms.txt ADDED
File without changes