daniel-de-leon commited on
Commit
1683bca
1 Parent(s): 534b64b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -63
app.py CHANGED
@@ -1,63 +1,63 @@
1
- import streamlit as st
2
- import streamlit.components.v1 as components
3
- from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
4
- pipeline)
5
- import shap
6
- from PIL import Image
7
-
8
- st.set_option('deprecation.showPyplotGlobalUse', False)
9
- output_width = 800
10
- output_height = 300
11
- rescale_logits = False
12
-
13
-
14
-
15
- st.set_page_config(page_title='Text Classification with Shap')
16
- logo = Image.open('Intel-logo.png')
17
- st.sidebar.image(logo)
18
- st.title('Interpreting HF Pipeline Text Classification with Shap')
19
-
20
- form = st.sidebar.form("Model Selection")
21
- form.header('Model Selection')
22
-
23
- model_name = form.text_area("Enter the name of the text classification LLM (note: model must be fine-tuned on a text classification task)", value = "Hate-speech-CNERG/bert-base-uncased-hatexplain")
24
- form.form_submit_button("Submit")
25
-
26
-
27
- @st.cache_data()
28
- def load_model(model_name):
29
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
30
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
31
-
32
- return tokenizer, model
33
-
34
- tokenizer, model = load_model(model_name)
35
- pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
36
- explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
37
-
38
- col1, col2 = st.columns(2)
39
- text = col1.text_area("Enter text input", value = "Classify me.")
40
-
41
- result = pred(text)
42
- top_pred = result[0][0]['label']
43
- col2.write('')
44
- for label in result[0]:
45
- col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
46
-
47
- shap_values = explainer([text])
48
-
49
- force_plot = shap.plots.text(shap_values, display=False)
50
- bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
51
-
52
- st.markdown("""
53
- <style>
54
- .big-font {
55
- font-size:35px !important;
56
- }
57
- </style>
58
- """, unsafe_allow_html=True)
59
- st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{top_pred}</i> Prediction</p></center>', unsafe_allow_html=True)
60
- st.pyplot(bar_plot, clear_figure=True)
61
-
62
- st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
63
- components.html(force_plot, height=output_height, width=output_width, scrolling=True)
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
4
+ pipeline)
5
+ import shap
6
+ from PIL import Image
7
+
8
+ st.set_option('deprecation.showPyplotGlobalUse', False)
9
+ output_width = 800
10
+ output_height = 300
11
+ rescale_logits = False
12
+
13
+
14
+
15
+ st.set_page_config(page_title='Text Classification with Shap')
16
+ logo = Image.open('Intel-logo.png')
17
+ st.sidebar.image(logo)
18
+ st.title('Interpreting HF Pipeline Text Classification with Shap')
19
+
20
+ form = st.sidebar.form("Model Selection")
21
+ form.header('Model Selection')
22
+
23
+ model_name = form.text_input("Enter the name of the text classification LLM (note: model must be fine-tuned on a text classification task)", value = "Hate-speech-CNERG/bert-base-uncased-hatexplain")
24
+ form.form_submit_button("Submit")
25
+
26
+
27
+ @st.cache_data()
28
+ def load_model(model_name):
29
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
30
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
31
+
32
+ return tokenizer, model
33
+
34
+ tokenizer, model = load_model(model_name)
35
+ pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
36
+ explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
37
+
38
+ col1, col2 = st.columns(2)
39
+ text = col1.text_area("Enter text input", value = "Classify me.")
40
+
41
+ result = pred(text)
42
+ top_pred = result[0][0]['label']
43
+ col2.write('')
44
+ for label in result[0]:
45
+ col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
46
+
47
+ shap_values = explainer([text])
48
+
49
+ force_plot = shap.plots.text(shap_values, display=False)
50
+ bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
51
+
52
+ st.markdown("""
53
+ <style>
54
+ .big-font {
55
+ font-size:35px !important;
56
+ }
57
+ </style>
58
+ """, unsafe_allow_html=True)
59
+ st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{top_pred}</i> Prediction</p></center>', unsafe_allow_html=True)
60
+ st.pyplot(bar_plot, clear_figure=True)
61
+
62
+ st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
63
+ components.html(force_plot, height=output_height, width=output_width, scrolling=True)