# companies_NER / app.py
# Last commit by chumpblocckami: a9cf63c "feat: added models choice"
import streamlit as st
import transformers
from annotated_text import annotated_text
# Background colour (CSS hex shorthand) passed to annotated_text for each
# entity group name produced by the NER pipeline.
ENTITY_TO_COLOR = dict(
    PER='#8ef',
    LOC='#faa',
    ORG='#afa',
    MISC='#fea',
)
# ``st.cache`` is deprecated; ``st.cache_resource`` is the recommended cache
# for unhashable, shared resources such as ML pipelines. Fall back to the old
# decorator on Streamlit releases that predate ``st.cache_resource``.
if hasattr(st, "cache_resource"):
    _cache_pipeline = st.cache_resource(show_spinner=False)
else:
    _cache_pipeline = st.cache(allow_output_mutation=True, show_spinner=False)


@_cache_pipeline
def get_pipe(model_name):
    """Build (once per model name) a token-classification pipeline.

    Args:
        model_name: Model identifier accepted by
            ``transformers.Auto*.from_pretrained``.

    Returns:
        A ``transformers`` token-classification pipeline using the
        ``"simple"`` aggregation strategy, so sub-word predictions are merged
        into entity groups with ``start``/``end`` character offsets.
    """
    model = transformers.AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    return transformers.pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )
def parse_text(text, prediction, colors=None, default_color="#ddd"):
    """Split *text* into plain segments and highlighted entity tuples.

    Args:
        text: The exact string that was passed to the NER pipeline.
        prediction: Entity dicts from the pipeline, each with ``"start"``,
            ``"end"`` and ``"entity_group"`` keys. Assumed sorted by
            ``"start"`` and non-overlapping (TODO confirm for every
            aggregation strategy used).
        colors: Mapping from entity group to highlight colour. Defaults to
            the module-level ``ENTITY_TO_COLOR``.
        default_color: Colour used for entity groups missing from *colors*,
            so models with extra labels do not crash rendering.

    Returns:
        A list suitable for ``annotated_text(*result)``: plain ``str`` pieces
        interleaved with ``(entity_text, entity_group, color)`` tuples.
        Empty plain pieces are omitted.
    """
    if colors is None:
        colors = ENTITY_TO_COLOR
    parsed_text = []
    cursor = 0
    for entity in prediction:
        start, end = entity["start"], entity["end"]
        if start > cursor:
            parsed_text.append(text[cursor:start])
        group = entity["entity_group"]
        # Slice the original text rather than using entity["word"]: the
        # decoded word can differ from the input (tokenizer normalisation,
        # spacing around punctuation, [UNK]), which would corrupt the output.
        parsed_text.append((text[start:end], group, colors.get(group, default_color)))
        cursor = end
    if cursor < len(text):
        parsed_text.append(text[cursor:])
    return parsed_text
st.set_page_config(page_title="Named Entity Recognition")
st.title("Named Entity Recognition")
st.write("Type text into the text box and then press 'Predict' to get the named entities.")

# Model identifiers offered to the user; each is loaded via get_pipe.
option = st.selectbox(
    'Select model',
    (
        "dslim/bert-base-NER",
        "dslim/bert-large-NER",
        "Davlan/bert-base-multilingual-cased-ner-hrl",
    ),
)
default_text = "Xbox v PlayStation: Giants clash over Call of Duty: Xbox owner Microsoft has hit back at claims its plan to buy the maker of Call of Duty may unfairly affect its rivals, including Sony, which owns PlayStation."
text = st.text_area('Enter text here:', value=default_text)
st.write('Model used for prediction:', option)
# Clicking the button re-runs the script; predictions below are shown for any
# non-empty text, including the default text on first load.
submit = st.button('Predict')

with st.spinner("Loading model..."):
    pipe = get_pipe(model_name=option)

# The previous condition `(submit and non_empty) or non_empty` always reduced
# to `non_empty`, so check that directly.
if text.strip():
    prediction = pipe(text)
    parsed_text = parse_text(text, prediction)
    st.header("Prediction:")
    annotated_text(*parsed_text)