import streamlit as st
st.set_page_config(
layout="centered", # Can be "centered" or "wide". In the future also "dashboard", etc.
initial_sidebar_state="auto", # Can be "auto", "expanded", "collapsed"
page_title='Named Entity Recognition', # String or None. Strings get appended with "• Streamlit".
page_icon='./favicon.png', # String, anything supported by st.image, or None.
)
import pandas as pd
import numpy as np
import os
import sys
sys.path.append(os.path.abspath('./'))
import streamlit_apps_config as config
from streamlit_ner_output import show_html2, jsl_display_annotations, get_color
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.sql import functions as F
from sparknlp_display import NerVisualizer
from pyspark.ml import Pipeline
from pyspark.sql.types import StringType
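# Start (or attach to) a Spark session with the Spark NLP jars loaded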
spark = sparknlp.start()
## Apply the NER page styling
st.markdown(config.STYLE_CONFIG, unsafe_allow_html=True)
root_path = config.project_path
########## Hide the main menu hamburger ########
hide_menu_style = """
<style>
#MainMenu {visibility: hidden;}
</style>
"""
st.markdown(hide_menu_style, unsafe_allow_html=True)
########## Side Bar ########
## Load the logo (newer version, wrapped in an href)
import base64
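# Cache the base64-encoded logo so the image file is only read once per session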
@st.cache(allow_output_mutation=True)
def get_base64_of_bin_file(bin_file):
with open(bin_file, 'rb') as f:
data = f.read()
return base64.b64encode(data).decode()
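# Wrap the base64-encoded image in a clickable <a> tag pointing at the given URL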
@st.cache(allow_output_mutation=True)
def get_img_with_href(local_img_path, target_url):
img_format = os.path.splitext(local_img_path)[-1].replace('.', '')
bin_str = get_base64_of_bin_file(local_img_path)
html_code = f'''
<a href="{target_url}">
<img height="90%" width="90%" src="data:image/{img_format};base64,{bin_str}" />
</a>'''
return html_code
logo_html = get_img_with_href('./jsl-logo.png', 'https://www.johnsnowlabs.com/')
st.sidebar.markdown(logo_html, unsafe_allow_html=True)
# Sidebar info
model_name = ["nerdl_fewnerd_100d"]
st.sidebar.title("Pretrained model to test")
selected_model = st.sidebar.selectbox("", model_name)
######## Main Page #########
app_title= "Detect up to 8 entity types in general domain texts"
app_description= "Named Entity Recognition model aimed to detect up to 8 entity types from general domain texts. This model was trained on the Few-NERD/inter public dataset using Spark NLP, and is available in Spark NLP Models hub (https://nlp.johnsnowlabs.com/models)"
st.title(app_title)
st.markdown("<h2>"+app_description+"</h2>" , unsafe_allow_html=True)
if selected_model == "nerdl_fewnerd_100d":
st.markdown("**`PERSON`** **,** **`ORGANIZATION`** **,** **`LOCATION`** **,** **`ART`** **,** **`BUILDING`** **,** **`PRODUCT`** **,** **`EVENT`** **,** **`OTHER`**", unsafe_allow_html=True)
st.subheader("")
#### Build the pipeline and run the model
@st.cache(allow_output_mutation=True)
def get_pipeline(text):
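    """Build a Spark NLP pipeline (document -> sentence -> token -> embeddings -> NER -> chunk),
    run it on the given text, and return the result as a pandas DataFrame."""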
documentAssembler = DocumentAssembler()\
.setInputCol("text")\
.setOutputCol("document")
sentenceDetector= SentenceDetector()\
.setInputCols(["document"])\
.setOutputCol("sentence")
tokenizer = Tokenizer()\
.setInputCols(["sentence"])\
.setOutputCol("token")
embeddings= WordEmbeddingsModel.pretrained("glove_100d")\
.setInputCols(["sentence", "token"])\
.setOutputCol("embeddings")
ner= NerDLModel.pretrained("nerdl_fewnerd_100d")\
.setInputCols(["document", "token", "embeddings"])\
.setOutputCol("ner")
ner_converter= NerConverter()\
.setInputCols(["sentence", "token", "ner"])\
.setOutputCol("ner_chunk")
pipeline = Pipeline(
stages = [
documentAssembler,
sentenceDetector,
tokenizer,
embeddings,
ner,
ner_converter
])
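    # All stages are pretrained, so fitting on an empty DataFrame only wires the pipeline together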
empty_df = spark.createDataFrame([[""]]).toDF("text")
pipeline_model = pipeline.fit(empty_df)
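    # Run the user text through the fitted pipeline and collect the annotations to pandas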
text_df= spark.createDataFrame(pd.DataFrame({"text": [text]}))
result= pipeline_model.transform(text_df).toPandas()
return result
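# Read user input; the (cached) pipeline runs on every new text entered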
text = st.text_input("Type your text here and press Enter to run:")
result = get_pipeline(text)
# Display the NER visualization
df = pd.DataFrame({"ner_chunk": result["ner_chunk"].iloc[0]})
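# Collect the set of entity labels found in the text to populate the sidebar filter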
labels_set = set()
for i in df['ner_chunk'].values:
labels_set.add(i[4]['entity'])
labels_set = list(labels_set)
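# Sidebar filter for which entity labels to display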
labels = st.sidebar.multiselect(
"NER Labels", options=labels_set, default=list(labels_set)
)
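# Render the annotated text, keeping only the labels selected in the sidebar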
show_html2(text, df, labels, "Text annotated with identified Named Entities")