File size: 10,745 Bytes
5775680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
026e9d4
 
94b0e7c
5775680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94b0e7c
5775680
 
94b0e7c
5775680
 
 
 
 
 
 
 
 
 
 
026e9d4
 
 
 
5775680
 
94b0e7c
5775680
94b0e7c
5775680
 
026e9d4
 
 
 
5775680
 
 
 
 
026e9d4
c8002ae
5775680
 
 
d273f14
 
 
5775680
94b0e7c
5775680
 
 
 
d273f14
 
026e9d4
d3d39c9
5775680
 
 
d273f14
5775680
 
c63015b
94b0e7c
5775680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94b0e7c
5775680
 
 
 
 
94b0e7c
 
 
 
 
5775680
94b0e7c
5775680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc63d14
c576416
9802396
5775680
 
94b0e7c
9802396
5775680
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import time
import streamlit as st
import torch
import string
from annotated_text import annotated_text

from flair.data import Sentence
from flair.models import SequenceTagger
from transformers import BertTokenizer, BertForMaskedLM
import BatchInference as bd
import batched_main_NER as ner
import aggregate_server_json as aggr
import json


# Passed to bd.BatchInference when the biomedical model is constructed.
DEFAULT_TOP_K = 20
# Marker looked for in the input text; when it is present, perform_inference
# skips POS tagging for that input.
SPECIFIC_TAG=":__entity__"



@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def POS_get_model(model_name):
  """Load and return the flair SequenceTagger named ``model_name``.

  The function is decorated with ``st.cache`` so Streamlit can reuse the
  loaded tagger instead of loading it again for the same ``model_name``.
  """
  return SequenceTagger.load(model_name)
  
def getPos(s: Sentence):
  """Collect (token text, label value) pairs from a tagged sentence.

  For every token in ``s.tokens`` and every annotation layer on that token,
  the token text and the value of the layer's first label are recorded.

  Returns:
    A ``(texts, labels)`` tuple of two parallel lists.
  """
  pairs = [
    (tok.text, tok.get_labels(layer)[0].value)
    for tok in s.tokens
    for layer in tok.annotation_layers.keys()
  ]
  texts = [text for text, _ in pairs]
  labels = [value for _, value in pairs]
  return texts, labels
  
def getDictFromPOS(texts, labels):
  """Build one five-element row per (text, label) pair.

  Each row is ``["dummy", text, label, "dummy", "dummy"]``; pairing stops at
  the shorter of the two inputs.
  """
  rows = []
  for word, tag in zip(texts, labels):
    rows.append(["dummy", word, tag, "dummy", "dummy"])
  return rows
  
def decode(tokenizer, pred_idx, top_clean):
  """Turn predicted vocabulary ids into a newline-separated list of words.

  Args:
    tokenizer: object exposing ``decode(id) -> str``.
    pred_idx: iterable of predicted token ids, best candidate first.
    top_clean: maximum number of cleaned tokens to keep.

  Returns:
    The first ``top_clean`` surviving tokens joined with ``"\\n"``. Single
    punctuation characters, ``[PAD]`` and empty decodes are dropped, and the
    ``##`` word-piece marker is removed.
  """
  # Use a set so membership is an exact match. Testing against one long
  # string was a substring test and wrongly dropped real tokens such as
  # "A", "P", "D" or "PAD".
  ignore_tokens = set(string.punctuation) | {'[PAD]'}
  tokens = []
  for w in pred_idx:
    # Collapse any whitespace the tokenizer puts inside the decoded piece.
    token = ''.join(tokenizer.decode(w).split())
    if token and token not in ignore_tokens:
      tokens.append(token.replace('##', ''))
  return '\n'.join(tokens[:top_clean])

def encode(tokenizer, text_sentence, add_special_tokens=True):
  """Tokenize a sentence containing ``<mask>`` and locate the mask position.

  Args:
    tokenizer: object exposing ``mask_token``, ``mask_token_id`` and
      ``encode(text, add_special_tokens=...)``.
    text_sentence: input text with a ``<mask>`` placeholder.
    add_special_tokens: forwarded to ``tokenizer.encode``.

  Returns:
    ``(input_ids, mask_idx)``: ``input_ids`` is a tensor holding one row of
    token ids, and ``mask_idx`` is the column of the first mask token.

  Raises:
    ValueError: if the encoded sentence contains no mask token.
  """
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
  words = text_sentence.split()
  # If the mask is the last token, append a "." so that models don't predict punctuation.
  if words and words[-1] == tokenizer.mask_token:
    text_sentence += ' .'

  # Encoding must happen for every input. It used to sit under the `if`
  # above, so a mask anywhere but the end raised UnboundLocalError.
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
  mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
  if not mask_positions:
    raise ValueError("text_sentence must contain a '<mask>' placeholder")
  return input_ids, mask_positions[0]

def get_all_predictions(text_sentence, top_clean=5, top_k=None):
  """Predict fill-ins for the mask in ``text_sentence`` with the BERT model.

  Args:
    text_sentence: text containing a ``<mask>`` placeholder.
    top_clean: number of cleaned predictions to return.
    top_k: number of raw candidates to take from the model before cleaning.
      Defaults to the larger of ``DEFAULT_TOP_K`` and ``top_clean``, so that
      filtering out punctuation still leaves enough candidates.

  Returns:
    ``{'bert': <newline-separated predictions>}``.

  NOTE(review): ``bert_tokenizer`` and ``bert_model`` are not defined in the
  visible part of this file; they must exist as module globals before this
  function is called. Confirm where they are meant to be created.
  """
  # `top_k` was previously an undefined name here; it is now resolved locally.
  if top_k is None:
    top_k = max(DEFAULT_TOP_K, top_clean)
  # ========================= BERT =================================
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
  with torch.no_grad():
    predict = bert_model(input_ids)[0]
  candidate_ids = predict[0, mask_idx, :].topk(top_k).indices.tolist()
  bert = decode(bert_tokenizer, candidate_ids, top_clean)
  return {'bert': bert}

def get_bert_prediction(input_text,top_k):
  """Predict the word that follows ``input_text``.

  A ``<mask>`` placeholder is appended to ``input_text`` and the result is
  passed to ``get_all_predictions``.

  Args:
    input_text: the text to extend.
    top_k: number of predictions to return; converted with ``int``.

  Returns:
    The dict from ``get_all_predictions``, or ``None`` if prediction failed.
  """
  try:
    return get_all_predictions(input_text + ' <mask>', top_clean=int(top_k))
  except Exception as error:
    # Stay best-effort (callers get None) but report the failure instead of
    # discarding it silently.
    print(f"get_bert_prediction failed: {error!r}")
    return None


def load_pos_model():
  """Return the POS tagger for the fixed "flair/pos-english" checkpoint."""
  return POS_get_model("flair/pos-english")
  

  

def init_session_states():
  """Seed ``st.session_state`` with defaults for keys that are not yet set.

  ``top_k`` starts at 20; the model and module slots start as ``None``.
  Keys that already exist are left untouched.
  """
  initial_values = {
    'top_k': 20,
    'pos_model': None,
    'bio_model': None,
    'ner_bio': None,
    'aggr': None,
  }
  for state_key, initial in initial_values.items():
    if state_key not in st.session_state:
      st.session_state[state_key] = initial
  

    
    
def get_pos_arr(input_text,display_area):
  """Return POS rows for ``input_text``, loading the POS tagger on first use.

  Args:
    input_text: sentence to tag.
    display_area: Streamlit placeholder used to show loading status.

  Returns:
    The list of rows built by ``getDictFromPOS`` from the tagger's output.
  """
  if st.session_state['pos_model'] is None:
    # The app loads two models (bio first, then POS); the old text said
    # "3 of 3", which contradicted the "1 of 2" shown for the bio model.
    display_area.text("Loading model 2 of 2. POS model...")
    st.session_state['pos_model'] = load_pos_model()
  s = Sentence(input_text)
  st.session_state['pos_model'].predict(s)
  texts, labels = getPos(s)
  return getDictFromPOS(texts, labels)
 
def perform_inference(text,display_area):
  """Run the NER pipeline on ``text`` and return the aggregated results.

  Components are created on first use and kept in ``st.session_state``.
  Progress messages are written to ``display_area``.

  Args:
    text: input sentence, optionally containing ``SPECIFIC_TAG`` markers.
    display_area: Streamlit placeholder used for status messages.

  Returns:
    Whatever the aggregation module's ``fetch_all`` returns.
  """
  state = st.session_state

  def ensure_loaded(key, status, build):
    # Create the component only if its session slot is still empty.
    if state[key] is None:
      display_area.text(status)
      state[key] = build()

  ensure_loaded(
    'bio_model',
    "Loading model 1 of 2. Bio model...",
    lambda: bd.BatchInference("bio/desc_a100_config.json",'ajitrajasekharan/biomedical',False,False,DEFAULT_TOP_K,True,True,"bio/","bio/a100_labels.txt",False),
  )

  # POS tags are computed only when the text has no SPECIFIC_TAG marker.
  pos_arr = None if SPECIFIC_TAG in text else get_pos_arr(text,display_area)

  ensure_loaded(
    'ner_bio',
    "Initializing BIO module...",
    lambda: ner.UnsupNER("bio/ner_a100_config.json"),
  )
  ensure_loaded(
    'aggr',
    "Initializing Aggregation module...",
    lambda: aggr.AggregateNER("./ensemble_config.json"),
  )

  display_area.text("Getting predictions from BIO model...")
  descriptors = state['bio_model'].get_descriptors(text,pos_arr)
  display_area.text("Computing BIO results...")
  tagged = json.loads(state['ner_bio'].tag_sentence_service(text,descriptors))

  # The same parsed result is passed twice to the aggregator.
  return state['aggr'].fetch_all(text,[tagged,tagged])
  

# Example sentences offered in the pull-down (see init_selectbox).
# main() looks up the chosen sentence's index here and uses the entry at the
# same index in sent_arr_masked, so the two lists must stay in the same order.
sent_arr = [
"Lou Gehrig who works for XCorp and lives in New York suffers from Parkinson's ",  
"Parkinson who works for XCorp and lives in New York suffers from Lou Gehrig's", 
"Her hypophysitis secondary to ipilimumab was well managed with supplemental hormones",
"the portfolio manager of the new cryptocurrency firm underwent a bone marrow biopsy for AML",
"lou gehrig was diagnosed with Parkinson's ",  
"A eGFR below 60 indicates chronic kidney disease", 
"Overexpression of EGFR occurs across a wide range of different cancers", 
"He was diagnosed with non small cell lung cancer",  
"There are no treatment options specifically indicated for ACD and physicians must utilize agents approved for other dermatology conditions", 
"As ACD has been implicated in apoptosis-resistant glioblastoma (GBM), there is a high medical need for identifying novel ACD-inducing drugs  ", 
"Patients treated with anticancer chemotherapy drugs ( ACD ) are vulnerable to infectious diseases due to immunosuppression and to the direct impact of ACD on their intestinal microbiota ", 
"In the LASOR trial , increasing daily imatinib dose from 400 to 600mg induced MMR at 12 and 24 months in 25% and 36% of the patients,        respectively, who had suboptimal cytogenetic responses "
]


# Index-aligned counterparts of sent_arr with ":__entity__" appended to the
# words to be tagged. main() sends these to inference when the text box is empty.
# The fourth entry carries no marker, so perform_inference computes POS tags
# for it.
sent_arr_masked = [
"Lou:__entity__ Gehrig:__entity__ who works for XCorp and lives in New York suffers from Parkinson's:__entity__ ",  
"Parkinson:__entity__ who works for XCorp and lives in New York suffers from Lou Gehrig's:__entity__", 
"Her hypophysitis:__entity__ secondary to ipilimumab:__entity__ was well managed with supplemental:__entity__ hormones:__entity__",
"the portfolio manager of the new cryptocurrency firm underwent a bone marrow biopsy for AML",
"lou:__entity__ gehrig:__entity__ was diagnosed with Parkinson's:__entity__ ",  
"A eGFR:__entity__ below 60 indicates chronic kidney disease", 
"Overexpression of EGFR:__entity__ occurs across a wide range of different cancers", 
"He was diagnosed with non:__entity__ small:__entity__ cell:__entity__ lung:__entity__ cancer:__entity__",  
"There are no treatment options specifically indicated for ACD:__entity__ and physicians must utilize agents approved for other dermatology conditions", 
"As ACD:__entity__ has been implicated in apoptosis-resistant glioblastoma (GBM), there is a high medical need for identifying novel ACD-inducing drugs  ", 
"Patients treated with anticancer chemotherapy drugs ( ACD:__entity__ ) are vulnerable to infectious diseases due to immunosuppression and to the direct impact of ACD on their intestinal microbiota ", 
"In the LASOR:__entity__ trial:__entity__ , increasing daily imatinib dose from 400 to 600mg induced MMR at 12 and 24 months in 25% and 36% of the patients,        respectively, who had suboptimal cytogenetic responses "
]

def init_selectbox():
  """Render the example-sentence pull-down and return the selected sentence."""
  chosen = st.selectbox(
    'Choose any of the sentences in pull-down below',
    sent_arr,
    key='my_choice',
  )
  return chosen
  
   
def on_text_change():
  """Callback: run inference on the text stored under ``my_text``.

  Reads ``st.session_state.my_text``, logs it to stdout and runs
  ``perform_inference`` on it.
  """
  text = st.session_state.my_text
  print("in callback: " + text)
  # perform_inference requires a placeholder for its status messages; the
  # previous one-argument call raised TypeError.
  perform_inference(text, st.empty())

def main():
  """Render the page and run NER on the submitted sentence.

  Shows a form with an example pull-down and a free-text box. On submit, the
  typed text is used if it is non-blank; otherwise the masked version of the
  selected example is used. The results are shown as JSON together with the
  elapsed time. Any exception is printed and shown on the page.
  """
  try:

    init_session_states()

    st.markdown("<h3 style='text-align: center;'>Biomedical NER using a pretrained model with <a href='https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html'>no fine tuning</a></h3>", unsafe_allow_html=True)

    st.write("This app uses 2 models.  A  Bert model  pretrained (**no fine tuning**) on biomedical corpus, and a POS tagger")

    with st.form('my_form'):
      selected_sentence = init_selectbox()
      text_input = st.text_area(label='Type any sentence below',value="")
      submit_button = st.form_submit_button('Submit')
      input_status_area = st.empty()
      display_area = st.empty()
      if submit_button:
        start = time.time()
        # A blank or whitespace-only box means "use the selected example".
        # The masked list is index-aligned with the pull-down list.
        if not text_input.strip():
          text_input = sent_arr_masked[sent_arr.index(selected_sentence)]
        input_status_area.text("Input sentence:  " + text_input)
        results = perform_inference(text_input,display_area)
        # Replace the progress messages with the final output.
        display_area.empty()
        with display_area.container():
          st.text(f"prediction took {time.time() - start:.2f}s")
          st.json(results)

    st.markdown("""
    <small style="font-size:16px; color: #8f8f8f; text-align: left"><i><b>Note:</b> The example sentences in the pull-down above only test biomedical entities. PHI entities are not tested since this model does not perform well on PHI entities. To see valid predictions for both biomedical and PHI entities  <a href='https://huggingface.co/spaces/ajitrajasekharan/NER-Biomedical-PHI-Ensemble'   target='_blank'>use this ensemble app</a></i></small>
  """, unsafe_allow_html=True)

    st.markdown("""
    <small style="font-size:16px; color: #7f7f7f; text-align: left"><br/><br/>Models used: <br/>(1) <a href='https://huggingface.co/ajitrajasekharan/biomedical' target='_blank'>Biomedical model</a> pretrained on Pubmed,Clinical trials and BookCorpus subset.<br/>(2) Flair POS tagger</small>
  """, unsafe_allow_html=True)
    st.markdown("""
    <h3 style="font-size:16px; color: #9f9f9f; text-align: center"><b> <a href='https://huggingface.co/spaces/ajitrajasekharan/Qualitative-pretrained-model-evaluation'   target='_blank'>App link to examine pretrained models</a> used to perform NER without fine tuning</b></h3>
  """, unsafe_allow_html=True)
    st.markdown("""
    <h3 style="font-size:16px; color: #9f9f9f; text-align: center">Github <a href='http://github.com/ajitrajasekharan/unsupervised_NER' target='_blank'>link to same working code </a>(without UI) as separate microservices</h3>
  """, unsafe_allow_html=True)

  except Exception as e:
    # Top-level boundary: log to stdout and show the traceback on the page.
    print("Some error occurred in main")
    st.exception(e)
	
# Entry point: build the page when this file is executed as a script.
if __name__ == "__main__":
   main()