daniel-de-leon commited on
Commit
534b64b
1 Parent(s): 62cbb06

add shap bar plot

Browse files
Files changed (2) hide show
  1. Intel-logo.png +0 -0
  2. app.py +34 -9
Intel-logo.png ADDED
app.py CHANGED
@@ -3,20 +3,24 @@ import streamlit.components.v1 as components
3
  from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
4
  pipeline)
5
  import shap
 
6
 
 
7
  output_width = 800
8
- output_height = 1000
9
  rescale_logits = False
10
 
 
 
11
  st.set_page_config(page_title='Text Classification with Shap')
 
 
12
  st.title('Interpreting HF Pipeline Text Classification with Shap')
13
 
14
- text = st.text_area("Enter text input", value = "Classify me.")
15
-
16
- form = st.sidebar.form("Main Settings")
17
- form.header('Main Settings')
18
 
19
- model_name = form.text_area("Enter the name of the text classification model", value = "Hate-speech-CNERG/bert-base-uncased-hatexplain")
20
  form.form_submit_button("Submit")
21
 
22
 
@@ -31,8 +35,29 @@ tokenizer, model = load_model(model_name)
31
  pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
32
  explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
33
 
 
 
 
 
 
 
 
 
 
34
  shap_values = explainer([text])
35
 
36
- shap_plot = shap.plots.text(shap_values, display=False)
37
- st.title('Interactive Shap Force Plot')
38
- components.html(shap_plot, height=output_height, width=output_width, scrolling=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
4
  pipeline)
5
  import shap
6
+ from PIL import Image
7
 
8
+ st.set_option('deprecation.showPyplotGlobalUse', False)
9
  output_width = 800
10
+ output_height = 300
11
  rescale_logits = False
12
 
13
+
14
+
15
  st.set_page_config(page_title='Text Classification with Shap')
16
+ logo = Image.open('Intel-logo.png')
17
+ st.sidebar.image(logo)
18
  st.title('Interpreting HF Pipeline Text Classification with Shap')
19
 
20
+ form = st.sidebar.form("Model Selection")
21
+ form.header('Model Selection')
 
 
22
 
23
+ model_name = form.text_area("Enter the name of the text classification LLM (note: model must be fine-tuned on a text classification task)", value = "Hate-speech-CNERG/bert-base-uncased-hatexplain")
24
  form.form_submit_button("Submit")
25
 
26
 
 
35
  pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
36
  explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
37
 
38
+ col1, col2 = st.columns(2)
39
+ text = col1.text_area("Enter text input", value = "Classify me.")
40
+
41
+ result = pred(text)
42
+ top_pred = result[0][0]['label']
43
+ col2.write('')
44
+ for label in result[0]:
45
+ col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
46
+
47
  shap_values = explainer([text])
48
 
49
+ force_plot = shap.plots.text(shap_values, display=False)
50
+ bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
51
+
52
+ st.markdown("""
53
+ <style>
54
+ .big-font {
55
+ font-size:35px !important;
56
+ }
57
+ </style>
58
+ """, unsafe_allow_html=True)
59
+ st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{top_pred}</i> Prediction</p></center>', unsafe_allow_html=True)
60
+ st.pyplot(bar_plot, clear_figure=True)
61
+
62
+ st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
63
+ components.html(force_plot, height=output_height, width=output_width, scrolling=True)