chumpblocckami commited on
Commit
b51cc8c
·
1 Parent(s): a986644

feat: added models choice

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
- from annotated_text import annotated_text
3
  import transformers
 
4
 
5
  ENTITY_TO_COLOR = {
6
  'PER': '#8ef',
@@ -9,14 +9,18 @@ ENTITY_TO_COLOR = {
9
  'MISC': '#fea',
10
  }
11
 
 
12
  @st.cache(allow_output_mutation=True, show_spinner=False)
13
- def get_pipe():
14
- model_name = "dslim/bert-base-NER"
15
  model = transformers.AutoModelForTokenClassification.from_pretrained(model_name)
16
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
17
- pipe = transformers.pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
 
 
 
18
  return pipe
19
 
 
20
  def parse_text(text, prediction):
21
  start = 0
22
  parsed_text = []
@@ -27,24 +31,25 @@ def parse_text(text, prediction):
27
  parsed_text.append(text[start:])
28
  return parsed_text
29
 
 
30
  st.set_page_config(page_title="Named Entity Recognition")
31
  st.title("Named Entity Recognition")
32
  st.write("Type text into the text box and then press 'Predict' to get the named entities.")
33
 
34
- default_text = "My name is John Smith. I work at Microsoft. I live in Paris. My favorite painting is the Mona Lisa."
 
35
 
 
36
  text = st.text_area('Enter text here:', value=default_text)
37
  submit = st.button('Predict')
38
 
39
  with st.spinner("Loading model..."):
40
- pipe = get_pipe()
41
 
42
  if (submit and len(text.strip()) > 0) or len(text.strip()) > 0:
43
-
44
  prediction = pipe(text)
45
 
46
  parsed_text = parse_text(text, prediction)
47
 
48
  st.header("Prediction:")
49
  annotated_text(*parsed_text)
50
-
 
1
  import streamlit as st
 
2
  import transformers
3
+ from annotated_text import annotated_text
4
 
5
  ENTITY_TO_COLOR = {
6
  'PER': '#8ef',
 
9
  'MISC': '#fea',
10
  }
11
 
12
+
13
  @st.cache(allow_output_mutation=True, show_spinner=False)
14
+ def get_pipe(model_name):
 
15
  model = transformers.AutoModelForTokenClassification.from_pretrained(model_name)
16
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
17
+ pipe = transformers.pipeline("token-classification",
18
+ model=model,
19
+ tokenizer=tokenizer,
20
+ aggregation_strategy="simple")
21
  return pipe
22
 
23
+
24
  def parse_text(text, prediction):
25
  start = 0
26
  parsed_text = []
 
31
  parsed_text.append(text[start:])
32
  return parsed_text
33
 
34
+
35
  st.set_page_config(page_title="Named Entity Recognition")
36
  st.title("Named Entity Recognition")
37
  st.write("Type text into the text box and then press 'Predict' to get the named entities.")
38
 
39
+ option = st.selectbox('Model', ("dslim/bert-base-NER", "flair/ner-english-fast", "Jean-Baptiste/camembert-ner"))
40
+ st.write('Selected model:', option)
41
 
42
+ default_text = "Xbox v PlayStation: Giants clash over Call of Duty: Xbox owner Microsoft has hit back at claims its plan to buy the maker of Call of Duty may unfairly affect its rivals, including Sony, which owns PlayStation."
43
  text = st.text_area('Enter text here:', value=default_text)
44
  submit = st.button('Predict')
45
 
46
  with st.spinner("Loading model..."):
47
+ pipe = get_pipe(model_name=option)
48
 
49
  if (submit and len(text.strip()) > 0) or len(text.strip()) > 0:
 
50
  prediction = pipe(text)
51
 
52
  parsed_text = parse_text(text, prediction)
53
 
54
  st.header("Prediction:")
55
  annotated_text(*parsed_text)