"""Streamlit demo app: named entity recognition with Hugging Face models."""

import streamlit as st
import transformers
from annotated_text import annotated_text

# Highlight colour for each entity group produced by the NER pipelines.
ENTITY_TO_COLOR = {
    'PER': '#8ef',
    'LOC': '#faa',
    'ORG': '#afa',
    'MISC': '#fea',
}
# Fallback colour for entity groups not listed above. The label set depends on
# the selected model, so an unmapped label must not crash the app.
DEFAULT_ENTITY_COLOR = '#ddd'


def _cache_model(func):
    """Cache *func*'s return value across Streamlit reruns without copying it.

    Prefers ``st.cache_resource`` (newer Streamlit) and falls back to the
    legacy ``st.cache(allow_output_mutation=True)`` on older versions.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_model
def get_pipe(model_name):
    """Build a token-classification pipeline for *model_name* (cached per name).

    ``aggregation_strategy="simple"`` makes the pipeline return grouped entity
    spans; ``parse_text`` consumes their ``entity_group``, ``start`` and
    ``end`` fields.
    """
    model = transformers.AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    pipe = transformers.pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )
    return pipe


def parse_text(text, prediction):
    """Split *text* into plain strings and ``(span, label, colour)`` tuples.

    Args:
        text: The original input string.
        prediction: Entity dicts from the pipeline, each with ``start`` and
            ``end`` (character offsets into *text*) and ``entity_group``.

    Returns:
        A list suitable for ``annotated_text(*parsed_text)``: unlabelled
        stretches of *text* as ``str`` and entities as 3-tuples.
    """
    parsed_text = []
    cursor = 0
    # Walk entities in text order so the gap slices below are never negative.
    for entity in sorted(prediction, key=lambda p: p["start"]):
        begin = max(entity["start"], cursor)  # clamp any overlap with the previous span
        end = entity["end"]
        if end <= begin:
            continue
        if begin > cursor:
            parsed_text.append(text[cursor:begin])
        label = entity["entity_group"]
        # Show the user's own substring rather than entity["word"]: "word" is
        # re-decoded from sub-word tokens and may differ from the input text.
        parsed_text.append(
            (text[begin:end], label, ENTITY_TO_COLOR.get(label, DEFAULT_ENTITY_COLOR))
        )
        cursor = end
    if cursor < len(text):
        parsed_text.append(text[cursor:])
    return parsed_text


def main():
    """Render the NER page: model picker, text box, and annotated output."""
    st.set_page_config(page_title="Named Entity Recognition")
    st.title("Named Entity Recognition")
    st.write("Type text into the text box and then press 'Predict' to get the named entities.")

    option = st.selectbox(
        'Select model',
        ("dslim/bert-base-NER", "dslim/bert-large-NER", "Davlan/bert-base-multilingual-cased-ner-hrl"),
    )

    default_text = "Xbox v PlayStation: Giants clash over Call of Duty: Xbox owner Microsoft has hit back at claims its plan to buy the maker of Call of Duty may unfairly affect its rivals, including Sony, which owns PlayStation."
    text = st.text_area('Enter text here:', value=default_text)
    st.write('Model used for prediction:', option)
    # The button's return value is not needed: clicking it triggers a rerun,
    # and every rerun with non-blank text re-predicts (the original condition
    # `(submit and ...) or <non-blank>` reduced to exactly this).
    st.button('Predict')

    with st.spinner("Loading model..."):
        pipe = get_pipe(model_name=option)

    if text.strip():
        prediction = pipe(text)
        parsed_text = parse_text(text, prediction)
        st.header("Prediction:")
        annotated_text(*parsed_text)


# `streamlit run` executes this script with __name__ == "__main__".
if __name__ == "__main__":
    main()