chumpblocckami commited on
Commit
1b878d8
·
1 Parent(s): b596e88

feat: add application file

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +53 -0
  3. requirements.txt +3 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Companies NER
3
- emoji: 👁
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: streamlit
 
1
  ---
2
  title: Companies NER
3
+ emoji: 💻
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: streamlit
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from annotated_text import annotated_text
3
+ import transformers
4
+
5
+ ENTITY_TO_COLOR = {
6
+ 'PER': '#8ef',
7
+ 'LOC': '#faa',
8
+ 'ORG': '#afa',
9
+ 'MISC': '#fea',
10
+ }
11
+
12
+ @st.cache(allow_output_mutation=True, show_spinner=False)
13
+ def get_pipe():
14
+ model_name = "dslim/bert-base-NER"
15
+ model = transformers.AutoModelForTokenClassification.from_pretrained(model_name)
16
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
17
+ pipe = transformers.pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
18
+ return pipe
19
+
20
+ def parse_text(text, prediction):
21
+ start = 0
22
+ parsed_text = []
23
+ for p in prediction:
24
+ parsed_text.append(text[start:p["start"]])
25
+ parsed_text.append((p["word"], p["entity_group"], ENTITY_TO_COLOR[p["entity_group"]]))
26
+ start = p["end"]
27
+ parsed_text.append(text[start:])
28
+ return parsed_text
29
+
30
+ st.set_page_config(page_title="Named Entity Recognition")
31
+ st.title("Named Entity Recognition")
32
+ st.write("Type text into the text box and then press 'Predict' to get the named entities.")
33
+
34
+ default_text = "My name is John Smith. I work at Microsoft. I live in Paris. My favorite painting is the Mona Lisa."
35
+
36
+ text = st.text_area('Enter text here:', value=default_text)
37
+ submit = st.button('Predict')
38
+
39
+ with st.spinner("Loading model..."):
40
+ pipe = get_pipe()
41
+
42
+ if (submit and len(text.strip()) > 0) or len(text.strip()) > 0:
43
+
44
+ prediction = pipe(text)
45
+
46
+ parsed_text = parse_text(text, prediction)
47
+
48
+ st.header("Prediction:")
49
+ annotated_text(*parsed_text)
50
+
51
+ st.header('Raw values:')
52
+ st.json(prediction)
53
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ st-annotated-text