import time
import streamlit as st
import torch
import string
from annotated_text import annotated_text
from flair.data import Sentence
from flair.models import SequenceTagger
from transformers import BertTokenizer, BertForMaskedLM
import BatchInference as bd
import batched_main_NER as ner
import aggregate_server_json as aggr
import json

DEFAULT_TOP_K = 20
SPECIFIC_TAG = ":__entity__"
def POS_get_model(model_name):
    val = SequenceTagger.load(model_name)  # Load the model
    return val
def getPos(s: Sentence):
    texts = []
    labels = []
    for t in s.tokens:
        for label in t.annotation_layers.keys():
            texts.append(t.text)
            labels.append(t.get_labels(label)[0].value)
    return texts, labels
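# Wrap each (token, POS label) pair in the five-column row format consumed by
# BatchInference.get_descriptors below; the "dummy" fields appear to be placeholders
# for columns this app does not use.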
def getDictFromPOS(texts, labels):
    return [["dummy", t, l, "dummy", "dummy"] for t, l in zip(texts, labels)]
def decode(tokenizer, pred_idx, top_clean):
    ignore_tokens = string.punctuation + '[PAD]'
    tokens = []
    for w in pred_idx:
        token = ''.join(tokenizer.decode(w).split())
        if token not in ignore_tokens:
            tokens.append(token.replace('##', ''))
    return '\n'.join(tokens[:top_clean])
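# Tokenize the sentence and return the input ids together with the position of the mask token.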
def encode(tokenizer, text_sentence, add_special_tokens=True):
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    # If <mask> is the last token, append a "." so that the model doesn't predict punctuation.
    if tokenizer.mask_token == text_sentence.split()[-1]:
        text_sentence += ' .'
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
    return input_ids, mask_idx
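# Note: get_all_predictions and get_bert_prediction reference module-level
# bert_tokenizer, bert_model and top_k globals that are not defined in this file,
# and neither function is called by the Streamlit flow below.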
def get_all_predictions(text_sentence, top_clean=5):
    # ========================= BERT =================================
    input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
    with torch.no_grad():
        predict = bert_model(input_ids)[0]
    bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
    return {'bert': bert}
def get_bert_prediction(input_text, top_k):
    try:
        input_text += ' <mask>'
        res = get_all_predictions(input_text, top_clean=int(top_k))
        return res
    except Exception as error:
        # Errors are silently swallowed; the function returns None on failure
        pass
def load_pos_model():
    checkpoint = "flair/pos-english"
    return POS_get_model(checkpoint)
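# Cache models and pipeline modules in Streamlit session state so they are loaded only once per session.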
def init_session_states():
    if 'top_k' not in st.session_state:
        st.session_state['top_k'] = 20
    if 'pos_model' not in st.session_state:
        st.session_state['pos_model'] = None
    if 'phi_model' not in st.session_state:
        st.session_state['phi_model'] = None
    if 'ner_phi' not in st.session_state:
        st.session_state['ner_phi'] = None
    if 'aggr' not in st.session_state:
        st.session_state['aggr'] = None
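# Run the Flair POS tagger on the input sentence and return per-token rows for the BERT pipeline.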
def get_pos_arr(input_text, display_area):
    if (st.session_state['pos_model'] is None):
        display_area.text("Loading model 2 of 2: POS model...")
        st.session_state['pos_model'] = load_pos_model()
    s = Sentence(input_text)
    st.session_state['pos_model'].predict(s)
    texts, labels = getPos(s)
    pos_results = getDictFromPOS(texts, labels)
    return pos_results
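# Full inference pipeline: BERT descriptor prediction, unsupervised NER tagging,
# then aggregation of the results.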
def perform_inference(text, display_area):
    if (st.session_state['phi_model'] is None):
        display_area.text("Loading model 1 of 2: BERT model...")
        st.session_state['phi_model'] = bd.BatchInference("bbc/desc_bbc_config.json", 'bert-base-cased', False, False, DEFAULT_TOP_K, True, True, "bbc/", "bbc/bbc_labels.txt", False)

    # Load the POS model if needed and get POS tags, unless the input already carries explicit entity tags
    if (SPECIFIC_TAG not in text):
        pos_arr = get_pos_arr(text, display_area)
    else:
        pos_arr = None

    if (st.session_state['ner_phi'] is None):
        display_area.text("Initializing BERT module...")
        st.session_state['ner_phi'] = ner.UnsupNER("bbc/ner_bbc_config.json")

    if (st.session_state['aggr'] is None):
        display_area.text("Initializing Aggregation module...")
        st.session_state['aggr'] = aggr.AggregateNER("./ensemble_config.json")

    display_area.text("Getting predictions from BERT model...")
    phi_results = st.session_state['phi_model'].get_descriptors(text, pos_arr)
    display_area.text("Computing NER results...")
    phi_ner = st.session_state['ner_phi'].tag_sentence_service(text, phi_results)
    display_area.text("Consolidating responses...")
    obj = phi_ner
    # The aggregator appears to expect results from two models; the single model's output is passed twice
    combined_arr = [obj, obj]
    aggregate_results = st.session_state['aggr'].fetch_all(text, combined_arr)
    return aggregate_results
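# Example sentences shown in the pull-down; sent_arr_masked holds the corresponding
# versions with :__entity__ tags marking the terms to label when the user types nothing.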
sent_arr = [
    "Washington who resigned from Washington flew to Washington",
    "John Doe flew from New York to Rio De Janiro ",
    "In 2020, John participated in the Winter Olympics and came third in Ice hockey",
    "Stanford called",
    "I met my girl friends at the pub ",
    "I met my New York friends at the pub",
    "I met my XCorp friends at the pub",
    "I met my two friends at the pub",
    "The sky turned dark in advance of the storm that was coming from the east ",
    "She loves to watch Sunday afternoon football with her family ",
    "The United States has the largest prison population in the world, and the highest per-capita incarceration rate",
    "Paul Erdos died at 83 "
]

sent_arr_masked = [
    "Washington:__entity__ who resigned from Washington:__entity__ flew to Washington:__entity__",
    "John:__entity__ Doe:__entity__ flew from New:__entity__ York:__entity__ to Rio:__entity__ De:__entity__ Janiro:__entity__ ",
    "In 2020:__entity__, Catherine:__entity__ Zeta:__entity__ Jones:__entity__ participated in the Winter:__entity__ Olympics:__entity__ and came third in Ice:__entity__ hockey:__entity__",
    "Stanford:__entity__ called",
    "I met my girl:__entity__ friends at the pub ",
    "I met my New:__entity__ York:__entity__ friends at the pub",
    "I met my XCorp:__entity__ friends at the pub",
    "I met my two:__entity__ friends at the pub",
    "The sky turned dark:__entity__ in advance of the storm that was coming from the east ",
    "She loves to watch Sunday afternoon football:__entity__ with her family ",
    "The United:__entity__ States:__entity__ has the largest prison population in the world, and the highest per-capita incarceration:__entity__ rate:__entity__",
    "Paul:__entity__ Erdos:__entity__ died at 83:__entity__ "
]
def init_selectbox():
    return st.selectbox(
        'Choose any of the sentences in pull-down below',
        sent_arr, key='my_choice')
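# Callback for the live text_area that is currently commented out in main(); note that
# it calls perform_inference without the display_area argument the function expects.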
def on_text_change():
    text = st.session_state.my_text
    print("in callback: " + text)
    perform_inference(text)
def main():
    try:
        init_session_states()

        st.markdown("<h4 style='text-align: center;'>NER of PERSON, LOCATION, ORG etc.</h4>", unsafe_allow_html=True)
        st.markdown("<h5 style='text-align: center;'>Using a pretrained BERT model with <a href='https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html'>no fine tuning</a><br/><br/></h5>", unsafe_allow_html=True)
        st.write("This app uses 2 models: Bert-base-cased (**no fine tuning**) and a POS tagger")

        with st.form('my_form'):
            selected_sentence = init_selectbox()
            text_input = st.text_area(label='Type any sentence below', value="")
            submit_button = st.form_submit_button('Submit')

        input_status_area = st.empty()
        display_area = st.empty()
        if submit_button:
            start = time.time()
            # If nothing was typed, fall back to the masked version of the selected example sentence
            if (len(text_input) == 0):
                text_input = sent_arr_masked[sent_arr.index(selected_sentence)]
            input_status_area.text("Input sentence: " + text_input)
            results = perform_inference(text_input, display_area)
            display_area.empty()
            with display_area.container():
                st.text(f"prediction took {time.time() - start:.2f}s")
                st.json(results)

        #input_text = st.text_area(
        #    label="Type any sentence",
        #    on_change=on_text_change, key='my_text'
        #    )

        st.markdown("""
        <small style="font-size:16px; color: #7f7f7f; text-align: left"><br/><br/>Models used: <br/>(1) Bert-base-cased (for PHI entities - Person/location/organization etc.)<br/>(2) Flair POS tagger</small>
        """, unsafe_allow_html=True)
        st.markdown("""
        <h3 style="font-size:16px; color: #9f9f9f; text-align: center"><b> <a href='https://huggingface.co/spaces/ajitrajasekharan/Qualitative-pretrained-model-evaluation' target='_blank'>App link to examine pretrained models</a> used to perform NER without fine tuning</b></h3>
        """, unsafe_allow_html=True)
        st.markdown("""
        <h3 style="font-size:16px; color: #9f9f9f; text-align: center">Github <a href='http://github.com/ajitrajasekharan/unsupervised_NER' target='_blank'>link to the same working code </a>(without UI) as separate microservices</h3>
        """, unsafe_allow_html=True)
    except Exception as e:
        print("Some error occurred in main")
        st.exception(e)


if __name__ == "__main__":
    main()