daniel-de-leon commited on
Commit
6a675a5
·
1 Parent(s): d4f4731

add basic shap text classification

Browse files
Files changed (2) hide show
  1. app.py +36 -2
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,4 +1,38 @@
1
  import streamlit as st
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
4
+ pipeline)
5
+ import shap
6
 
7
+ output_width = 800
8
+ output_height = 1000
9
+ rescale_logits = False
10
+
11
+ st.set_page_config(page_title='Text Classification with Shap')
12
+ st.title('Interpreting HF Pipeline Text Classification with Shap')
13
+
14
+ text = st.text_area("Enter text input", value = "Classify me.")
15
+
16
+ form = st.sidebar.form("Main Settings")
17
+ form.header('Main Settings')
18
+
19
+ model_name = form.text_area("Enter the name of the text classification model", value = "Hate-speech-CNERG/bert-base-uncased-hatexplain")
20
+ form.form_submit_button("Submit")
21
+
22
+
23
+ @st.cache_data()
24
+ def load_model(model_name):
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
26
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
27
+
28
+ return tokenizer, model
29
+
30
+ tokenizer, model = load_model(model_name)
31
+ pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
32
+ explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
33
+
34
+ shap_values = explainer([text])
35
+
36
+ shap_plot = shap.plots.text(shap_values, display=False)
37
+ st.title('Interactive Shap Force Plot')
38
+ components.html(shap_plot, height=output_height, width=output_width, scrolling=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ shap