File size: 7,845 Bytes
854a552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2485275
854a552
 
 
 
 
 
2485275
854a552
2485275
854a552
 
2485275
854a552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2485275
 
854a552
 
 
 
 
 
 
 
 
 
 
 
 
2485275
 
854a552
 
281974e
854a552
 
 
 
 
 
2485275
854a552
2485275
 
854a552
281974e
854a552
2485275
854a552
 
 
 
 
2485275
 
854a552
 
 
2485275
 
 
 
90a2a07
854a552
 
 
2485275
90a2a07
854a552
90a2a07
2485275
854a552
 
281974e
 
2485275
 
 
 
 
 
 
 
854a552
 
 
 
281974e
 
2485275
 
 
 
 
 
 
 
854a552
 
 
 
 
 
2485275
 
854a552
 
 
 
 
 
 
 
 
2485275
90a2a07
 
2485275
 
 
90a2a07
2485275
 
854a552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2485275
 
 
 
854a552
 
 
 
 
2485275
854a552
90a2a07
854a552
 
 
 
 
 
 
 
 
2485275
 
 
854a552
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import time
import streamlit as st
import torch
import string
from annotated_text import annotated_text

from flair.data import Sentence
from flair.models import SequenceTagger
from transformers import BertTokenizer, BertForMaskedLM
import BatchInference as bd
import batched_main_NER as ner
import aggregate_server_json as aggr
import json


# Passed positionally to bd.BatchInference in perform_inference.
# NOTE(review): presumably the number of top predictions the PHI model keeps — confirm against BatchInference.
DEFAULT_TOP_K = 20
# Suffix attached to words in sent_arr_masked; when it appears in the input
# text, perform_inference skips POS tagging and passes pos_arr=None.
SPECIFIC_TAG=":__entity__"



@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def POS_get_model(model_name):
  """Load the Flair sequence tagger identified by ``model_name`` and return it.

  The function is wrapped in ``st.cache`` with output mutation allowed and
  Streamlit warnings suppressed.
  """
  return SequenceTagger.load(model_name)

def getPos(s: Sentence):
  """Collect token texts and their first label value per annotation layer.

  For every token in ``s.tokens`` and every key of that token's
  ``annotation_layers``, one (text, label) pair is produced; a token with
  several layers therefore appears several times.

  Returns:
    A tuple ``(texts, labels)`` of two parallel lists.
  """
  pairs = [
    (tok.text, tok.get_labels(layer)[0].value)
    for tok in s.tokens
    for layer in tok.annotation_layers.keys()
  ]
  texts = [text for text, _ in pairs]
  labels = [tag for _, tag in pairs]
  return texts, labels

def getDictFromPOS(texts, labels):
  """Pair each text with its label as a five-column row.

  Each row is ``["dummy", text, label, "dummy", "dummy"]``. Pairing stops at
  the shorter of the two inputs.
  """
  rows = []
  for word, tag in zip(texts, labels):
    row = ["dummy"] * 5
    row[1] = word
    row[2] = tag
    rows.append(row)
  return rows

def decode(tokenizer, pred_idx, top_clean):
  """Turn predicted vocabulary ids into a newline-separated list of words.

  Args:
    tokenizer: object exposing ``decode(token_id)`` that returns a string.
    pred_idx: iterable of predicted token ids, in ranking order.
    top_clean: maximum number of cleaned tokens to keep.

  Returns:
    The first ``top_clean`` surviving tokens joined by newlines, with any
    ``##`` word-piece marker removed. Empty tokens, ``[PAD]`` and tokens
    made only of punctuation characters are skipped.
  """
  punctuation = set(string.punctuation)
  tokens = []
  for w in pred_idx:
    # Drop any whitespace the tokenizer puts inside the decoded piece.
    token = ''.join(tokenizer.decode(w).split())
    # Exact matching is required here: a substring test against
    # string.punctuation + '[PAD]' would also reject real words such as
    # "A", "D" or "PA", because they are substrings of "[PAD]".
    if not token or token == '[PAD]' or all(ch in punctuation for ch in token):
      continue
    tokens.append(token.replace('##', ''))
  return '\n'.join(tokens[:top_clean])

def encode(tokenizer, text_sentence, add_special_tokens=True):
  """Encode a sentence containing ``<mask>`` and locate the mask position.

  Args:
    tokenizer: object exposing ``mask_token``, ``mask_token_id`` and
      ``encode(text, add_special_tokens=...)`` returning a list of ids.
    text_sentence: input text; every ``<mask>`` is replaced by the
      tokenizer's own mask token.
    add_special_tokens: forwarded to ``tokenizer.encode``.

  Returns:
    ``(input_ids, mask_idx)`` where ``input_ids`` is a tensor holding one row
    of token ids and ``mask_idx`` is the column of the first mask token.

  Raises:
    ValueError: if the encoded sentence contains no mask token.
  """
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
  # If the mask is the last word, append a "." so that models don't predict
  # punctuation. The emptiness check keeps blank input from raising here.
  words = text_sentence.split()
  if words and words[-1] == tokenizer.mask_token:
    text_sentence += ' .'

  # Encoding must run for every sentence, not only when the mask is last.
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
  mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
  if not mask_positions:
    raise ValueError("text_sentence does not contain a mask token")
  return input_ids, mask_positions[0]

def get_all_predictions(text_sentence, top_clean=5, top_k=None):
  """Predict fill-ins for the mask in ``text_sentence`` with the BERT model.

  Args:
    text_sentence: sentence containing a ``<mask>`` placeholder.
    top_clean: number of cleaned predictions to keep (see ``decode``).
    top_k: number of raw candidates taken from the model output before
      cleaning. Defaults to the larger of ``DEFAULT_TOP_K`` and
      ``top_clean`` so cleaning has enough candidates to choose from.

  Returns:
    ``{'bert': <newline-separated predictions>}``.

  NOTE(review): ``bert_tokenizer`` and ``bert_model`` are not defined in the
  visible part of this file; they must exist at module level before this
  function can run — confirm where they are created.
  """
  if top_k is None:
    top_k = max(DEFAULT_TOP_K, top_clean)
  # ========================= BERT =================================
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
  with torch.no_grad():
    predict = bert_model(input_ids)[0]
  candidate_ids = predict[0, mask_idx, :].topk(top_k).indices.tolist()
  return {'bert': decode(bert_tokenizer, candidate_ids, top_clean)}

def get_bert_prediction(input_text,top_k):
  """Best-effort masked-word prediction for the word following ``input_text``.

  Appends `` <mask>`` to the text and asks ``get_all_predictions`` for
  ``int(top_k)`` cleaned predictions.

  Returns:
    The prediction dict, or ``None`` if anything fails. Failures are logged
    with their traceback instead of being silently discarded.
  """
  import logging  # local import: keeps this block self-contained

  try:
    return get_all_predictions(input_text + ' <mask>', top_clean=int(top_k))
  except Exception:
    # Deliberately broad: callers rely on this never raising.
    logging.getLogger(__name__).exception("BERT prediction failed for input %r", input_text)
    return None


def load_pos_model():
  """Return the ``flair/pos-english`` tagger obtained through ``POS_get_model``."""
  return POS_get_model("flair/pos-english")




def init_session_states():
  """Give every session-state key used by this app an initial value.

  Keys that already exist in ``st.session_state`` are left untouched.
  """
  defaults = (
    ('top_k', 20),
    ('pos_model', None),
    ('phi_model', None),
    ('ner_phi', None),
    ('aggr', None),
  )
  for key, initial in defaults:
    if key not in st.session_state:
      st.session_state[key] = initial



def get_pos_arr(input_text,display_area):
  """Tag ``input_text`` with the POS model and return one row per label.

  The POS model is loaded on first use and kept in
  ``st.session_state['pos_model']``; a status line is written to
  ``display_area`` while it loads.

  Returns:
    Rows of the form ``["dummy", token_text, label, "dummy", "dummy"]``.
  """
  state = st.session_state
  if state['pos_model'] is None:
    display_area.text("Loading model 2 of 2.Loading POS model...")
    state['pos_model'] = load_pos_model()

  sentence = Sentence(input_text)
  state['pos_model'].predict(sentence)
  token_texts, token_labels = getPos(sentence)
  return getDictFromPOS(token_texts, token_labels)

def perform_inference(text,display_area):
  """Run the PHI tagging pipeline on ``text`` and return its result.

  Models are created on first use and kept in ``st.session_state``. Progress
  messages are written to ``display_area`` along the way.

  Args:
    text: sentence to tag; if it contains ``SPECIFIC_TAG`` no POS tagging is
      performed and ``None`` is passed as the POS array.
    display_area: Streamlit element whose ``text`` method shows status lines.

  Returns:
    Whatever ``tag_sentence_service`` of the NER module returns
    (NOTE(review): ``main`` renders it with ``st.json`` — confirm the type).
  """
  state = st.session_state

  if state['phi_model'] is None:
    display_area.text("Loading model 1 of 2. PHI model...")
    state['phi_model'] = bd.BatchInference(
      "bbc/desc_bbc_config.json",
      'bert-base-cased',
      False,
      False,
      DEFAULT_TOP_K,
      True,
      True,
      "bbc/",
      "bbc/bbc_labels.txt",
      False,
    )

  # POS tags are only computed when the caller did not tag words explicitly.
  pos_arr = None if SPECIFIC_TAG in text else get_pos_arr(text, display_area)

  if state['ner_phi'] is None:
    display_area.text("Initializing PHI module...")
    state['ner_phi'] = ner.UnsupNER("bbc/ner_bbc_config.json")

  display_area.text("Getting results from PHI model...")
  descriptors = state['phi_model'].get_descriptors(text, pos_arr)
  display_area.text("Aggregating BIO & PHI results...")

  return state['ner_phi'].tag_sentence_service(text, descriptors)


# Sentences offered in the pull-down. Entry i must read exactly like
# sent_arr_masked[i] with the ":__entity__" markers removed, because main()
# looks up the tagged sentence by the index of the one the user selected.
sent_arr = [
"John Doe flew from New York to Rio De Janiro ",
"In 2020, Catherine Zeta Jones participated in the Winter Olympics and came third in Ice hockey",
"Stanford called",
"I met my girl friends at the pub ",
"I met my New York friends at the pub",
"I met my XCorp friends at the pub",
"I met my two friends at the pub",
"The sky turned dark in advance of the storm that was coming from the east ",
"She loves to watch Sunday afternoon football with her family ",
"Paul Erdos died at 83 "
]


# Tagged counterparts of sent_arr (same order): the text actually sent to the
# model when the user submits a pull-down choice.
sent_arr_masked = [
"John:__entity__ Doe:__entity__ flew from New York to Rio:__entity__ De:__entity__ Janiro:__entity__ ",
"In 2020:__entity__, Catherine:__entity__ Zeta:__entity__ Jones:__entity__ participated in the Winter:__entity__ Olympics:__entity__ and came third in Ice:__entity__ hockey:__entity__",
"Stanford:__entity__ called",
"I met my girl:__entity__ friends at the pub ",
"I met my New:__entity__ York:__entity__ friends at the pub",
"I met my XCorp:__entity__ friends at the pub",
"I met my two:__entity__ friends at the pub",
"The sky turned dark:__entity__ in advance of the storm that was coming from the east ",
"She loves to watch Sunday afternoon football:__entity__ with her family ",
"Paul:__entity__ Erdos:__entity__ died at 83:__entity__ "
]

def init_selectbox():
  """Render the sample-sentence pull-down and return the selected sentence.

  The widget is registered under the session-state key ``my_choice``.
  """
  prompt = 'Choose any of the sentences in pull-down below'
  return st.selectbox(prompt, sent_arr, key='my_choice')


def on_text_change():
  """Widget callback: run inference on the text stored under ``my_text``.

  ``perform_inference`` requires a display area for its status messages, so
  a fresh placeholder is created for it here.
  """
  text = st.session_state.my_text
  print("in callback: " + text)
  perform_inference(text, st.empty())

def main():
  """Render the app page and run inference when the form is submitted.

  The page shows a title, a form with a sample-sentence pull-down and a free
  text area, and footer links. On submit, the typed text is used if present;
  otherwise the tagged counterpart of the selected sample sentence is used.
  Any exception is printed to stdout and shown on the page.
  """
  try:
    init_session_states()

    st.markdown("<h3 style='text-align: center;'>NER using pretrained models with <a href='https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html'>no fine tuning</a><br/><br/></h3>", unsafe_allow_html=True)

    st.write("This app uses 2 models.  Bert-base-cased(**no fine tuning**) and a POS tagger")

    with st.form('my_form'):
      selected_sentence = init_selectbox()
      text_input = st.text_area(label='Type any sentence below',value="")
      submit_button = st.form_submit_button('Submit')
      input_status_area = st.empty()
      display_area = st.empty()
      if submit_button:
        start = time.time()
        if not text_input:
          # Nothing typed: fall back to the tagged version of the sample.
          text_input = sent_arr_masked[sent_arr.index(selected_sentence)]
        input_status_area.text("Input sentence:  " + text_input)
        results = perform_inference(text_input,display_area)
        # Replace the progress messages with the final output.
        display_area.empty()
        with display_area.container():
          st.text(f"prediction took {time.time() - start:.2f}s")
          st.json(results)

    # The closing quotes of each block below must not be preceded by any
    # other character, otherwise it becomes part of the rendered markdown.
    st.markdown("""
    <small style="font-size:16px; color: #7f7f7f; text-align: left"><br/><br/>Models used: <br/>(1) Bert-base-cased (for PHI entities - Person/location/organization etc.)<br/>(2) Flair POS tagger</small>
  """, unsafe_allow_html=True)
    st.markdown("""
    <h3 style="font-size:16px; color: #9f9f9f; text-align: center"><b> <a href='https://huggingface.co/spaces/ajitrajasekharan/Qualitative-pretrained-model-evaluation'   target='_blank'>App link to examine pretrained models</a> used to perform NER without fine tuning</b></h3>
  """, unsafe_allow_html=True)
    st.markdown("""
    <h3 style="font-size:16px; color: #9f9f9f; text-align: center">Github <a href='http://github.com/ajitrajasekharan/unsupervised_NER' target='_blank'>link to same working code </a>(without UI) as separate microservices</h3>
  """, unsafe_allow_html=True)

  except Exception as e:
    # Top-level boundary: report the failure on the console and on the page.
    print("Some error occurred in main")
    st.exception(e)

# Build the page only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
   main()