# Source: Hugging Face Space file view (author: daniel-de-leon).
# Last commit 534b64b: "add shap bar plot" (file size 2.23 kB).
import contextlib
import html

import shap
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
from streamlit.errors import StreamlitAPIException
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          pipeline)
# Streamlit requires set_page_config to be the first Streamlit command the
# script issues, so it runs before anything else touches `st`.
st.set_page_config(page_title='Text Classification with Shap')

# Hide the warning older Streamlit versions show when st.pyplot() is called
# without an explicit figure. Newer releases dropped this config key and
# raise StreamlitAPIException for unknown keys, so its absence is tolerated
# instead of crashing the app at startup.
with contextlib.suppress(StreamlitAPIException):
    st.set_option('deprecation.showPyplotGlobalUse', False)

# Size (in pixels) of the embedded interactive SHAP text plot.
output_width = 800
output_height = 300
# Passed to shap.Explainer as `rescale_to_logits`.
rescale_logits = False

logo = Image.open('Intel-logo.png')
st.sidebar.image(logo)
st.title('Interpreting HF Pipeline Text Classification with Shap')

# Sidebar form for choosing the model. Widgets inside an st.form only send
# their new value when the form's submit button is pressed.
form = st.sidebar.form("Model Selection")
form.header('Model Selection')
# Strip surrounding whitespace/newlines: the text area makes them easy to
# type, and they would otherwise be passed verbatim to from_pretrained.
model_name = form.text_area(
    "Enter the name of the text classification LLM (note: model must be fine-tuned on a text classification task)",
    value="Hate-speech-CNERG/bert-base-uncased-hatexplain",
).strip()
form.form_submit_button("Submit")
@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and sequence-classification model for *model_name*.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``:
    ``cache_data`` pickles the return value and gives every rerun a fresh
    copy, which for a transformer model is slow and memory-hungry.
    ``cache_resource`` keeps one shared instance per ``model_name``, which is
    Streamlit's recommended cache for ML models.

    Args:
        model_name: Hugging Face Hub model id (or local path) understood by
            ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Whatever ``from_pretrained`` raises when the model cannot be loaded
        (e.g. ``OSError`` for an unknown model id).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model
def _as_figure(plot):
    """Return the object to hand to ``st.pyplot`` for a shap plot result.

    Depending on the installed shap version, ``shap.plots.bar(..., show=False)``
    returns either ``None`` (the plot lives on the current global figure) or
    the matplotlib ``Axes`` it drew on. ``st.pyplot`` expects a Figure (or
    ``None`` for the global figure), so an Axes is unwrapped via its
    ``.figure`` attribute and anything else is passed through unchanged.
    """
    return getattr(plot, "figure", plot)


def _label_scores(result):
    """Return the list of ``{'label', 'score'}`` dicts for a single input.

    For a single string, the text-classification pipeline returns either
    ``[[{...}, ...]]`` or ``[{...}, ...]`` depending on the transformers
    version and how ``top_k`` was supplied; both shapes are accepted.
    """
    if result and isinstance(result[0], list):
        return result[0]
    return result


try:
    tokenizer, model = load_model(model_name)
except (OSError, ValueError) as err:
    # transformers typically reports unknown/unreachable model ids as OSError
    # and unsupported architectures or malformed ids as ValueError. Show the
    # problem in the UI instead of a raw traceback and end this run.
    st.error(f"Could not load model '{model_name}': {err}")
    st.stop()

# top_k=None asks the pipeline for a score for every label.
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
explainer = shap.Explainer(pred, rescale_to_logits=rescale_logits)

col1, col2 = st.columns(2)
text = col1.text_area("Enter text input", value="Classify me.")
if not text.strip():
    # Nothing to classify or explain; skip the model and SHAP work.
    col1.info("Enter some text to classify.")
    st.stop()

result = pred(text)
scores = _label_scores(result)
# Choose the highest-scoring label explicitly rather than relying on the
# order in which the pipeline lists labels.
top_pred = max(scores, key=lambda entry: entry['score'])['label']
col2.write('')
for label in scores:
    col2.write(f'**{label["label"]}**: {label["score"]: .2f}')

shap_values = explainer([text])
force_plot = shap.plots.text(shap_values, display=False)
bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)

# CSS class used by the section headings below.
st.markdown("""
<style>
.big-font {
font-size:35px !important;
}
</style>
""", unsafe_allow_html=True)
# top_pred comes from the config of a user-chosen model, so escape it before
# embedding it in raw HTML.
st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{html.escape(top_pred)}</i> Prediction</p></center>', unsafe_allow_html=True)
st.pyplot(_as_figure(bar_plot), clear_figure=True)
st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
components.html(force_plot, height=output_height, width=output_width, scrolling=True)