model_demo / app.py
alanakbik's picture
Make compatible with older streamlit version
b3c2743
raw
history blame
3.69 kB
import spacy.displacy
import streamlit as st
from flair.models import SequenceTagger
from flair.splitter import SegtokSentenceSplitter
from colorhash import ColorHash
# st.title("Flair NER Demo")
st.set_page_config(layout="centered")

# Menu caption shown to the user -> model identifier handed to SequenceTagger.load
model_map = {
    "find Entities (default)": "ner",
    "find Entities (18-class)": "ner-ontonotes",
    "find Parts-of-Speech": "upos",
}

# Block 1: model picker. The widget's own label is collapsed, so the subheader
# acts as the visible heading. Note that the selected value is the menu caption
# (a key of model_map), not the model identifier itself.
st.subheader("Select Model")
model_choices = model_map.keys()
selected_model_id = st.selectbox(
    "This is a check box",
    model_choices,
    label_visibility="collapsed",
)

# Block 2: free-text input, pre-filled with a short example sentence.
st.subheader("Input your text here")
text_area_options = {
    "value": "George was born in Washington.",
    "height": 128,
    "max_chars": None,
    "label_visibility": "collapsed",
}
input_text = st.text_area("Write or Paste Text Below", **text_area_options)
# st.cache was deprecated in Streamlit 1.18; st.cache_resource is its
# replacement for objects such as ML models. Prefer the new decorator when the
# installed Streamlit provides it and fall back to the legacy one otherwise, so
# the app keeps working on older Streamlit versions.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def get_model(model_name):
    """Load the Flair tagger that belongs to a menu entry.

    The result is cached by Streamlit, so the loaded tagger is reused across
    script reruns instead of being loaded again on every interaction.

    Args:
        model_name: A key of ``model_map`` (the caption shown in the model
            select box).

    Returns:
        The object returned by ``SequenceTagger.load`` for the mapped model
        identifier.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_map``.
    """
    return SequenceTagger.load(model_map[model_name])
def get_html(html: str):
    """Wrap markup in a bordered, horizontally scrollable ``<div>``.

    Every newline in ``html`` is replaced by a single space before the markup
    is placed inside the container element.
    """
    container = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
    # Splitting on "\n" and re-joining with " " swaps each newline for a space.
    single_line = " ".join(html.split("\n"))
    return container.format(single_line)
def color_variant(hex_color, brightness_offset=1):
    """Return a lighter or darker variant of a color like ``#87c95f``.

    Adapted from:
    https://chase-seibert.github.io/blog/2011/07/29/python-calculate-lighterdarker-rgb-colors.html

    Args:
        hex_color: A 7-character color string in ``#rrggbb`` form.
        brightness_offset: Integer added to each of the three channels;
            positive values lighten, negative values darken.

    Returns:
        The adjusted color as a 7-character ``#rrggbb`` string, with every
        channel clamped to the range 0..255.

    Raises:
        ValueError: If ``hex_color`` is not exactly 7 characters long, or a
            channel is not valid hexadecimal.
    """
    if len(hex_color) != 7:
        raise ValueError(f"Passed {hex_color} into color_variant(), needs to be in #87c95f format.")
    # Characters 1-2, 3-4 and 5-6 hold the red, green and blue channels.
    shifted = [int(hex_color[pos:pos + 2], 16) + brightness_offset for pos in (1, 3, 5)]
    # Keep every channel inside the valid 0..255 range.
    clamped = [min(255, max(0, channel)) for channel in shifted]
    # Always emit two hex digits per channel: values below 16 would otherwise
    # collapse to one digit and yield a malformed color such as "#111".
    return "#" + "".join(f"{channel:02x}" for channel in clamped)
# Block 3: tag the text on demand and show the annotated result
button_clicked = st.button("**Click here** to tag the input text", key=None)

if button_clicked:
    # Break the raw input into sentences.
    sentences = SegtokSentenceSplitter().split(input_text)

    # Fetch the selected tagger and run it; the predictions are read back
    # from the sentence objects below.
    tagger = get_model(selected_model_id)
    tagger.predict(sentences)

    # One displaCy span per predicted label. Each span is shifted by its
    # sentence's start_position so the offsets index into input_text
    # (assumes the data point positions are sentence-relative).
    entities = [
        {
            "start": label.data_point.start_position + sentence.start_position,
            "end": label.data_point.end_position + sentence.start_position,
            "label": label.value,
        }
        for sentence in sentences
        for label in sentence.get_labels()
    ]
    displacy_doc = {"ents": entities, "text": input_text, "title": None}

    # Give every distinct label its own color, derived from the label text
    # and shifted by the brightness offset.
    distinct_labels = {entity["label"] for entity in entities}
    label_colors = {
        name: color_variant(ColorHash(name).hex, brightness_offset=85)
        for name in distinct_labels
    }

    # Render with displaCy in manual mode, i.e. from the plain dict above.
    rendered = spacy.displacy.render(
        displacy_doc,
        style="ent",
        minify=True,
        manual=True,
        options={"colors": label_colors},
    )

    entity_css = "<style>mark.entity { display: inline-block }</style>"
    st.subheader("Found entities")
    st.write(f"{entity_css}{get_html(rendered)}", unsafe_allow_html=True)