# Streamlit app: run a Hugging Face text-classification pipeline on user
# supplied text and render the SHAP text explanation as embedded HTML.

import shap
import streamlit as st
import streamlit.components.v1 as components
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

# Size (in pixels) of the embedded HTML component that shows the SHAP plot.
output_width = 800
output_height = 1000
# Forwarded to shap.Explainer as `rescale_to_logits` below.
rescale_logits = False

st.set_page_config(page_title='Text Classification with Shap')
st.title('Interpreting HF Pipeline Text Classification with Shap')

text = st.text_area("Enter text input", value="Classify me.")

# Sidebar form: the script only sees a new model name after "Submit".
form = st.sidebar.form("Main Settings")
form.header('Main Settings')
model_name = form.text_area(
    "Enter the name of the text classification model",
    value="Hate-speech-CNERG/bert-base-uncased-hatexplain",
)
form.form_submit_button("Submit")

# A multi-line text area easily picks up a trailing newline or space, which
# would make the model lookup fail; normalise before using it as an id.
model_name = model_name.strip()
if not model_name:
    st.warning("Enter the name of a text classification model.")
    st.stop()


@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and sequence-classification model for *model_name*.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``: the
    latter pickles the return value and hands back a fresh deserialized copy
    on every script rerun, whereas ``cache_resource`` keeps a single shared
    instance, which is what Streamlit documents for ML models.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


tokenizer, model = load_model(model_name)

# Nothing to attribute for blank input, so ask for text instead of running
# the explainer on it.
if not text.strip():
    st.warning("Enter some text to classify.")
    st.stop()

# top_k=None makes the pipeline return a score for every label.
pred = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
)

# NOTE(review): confirm that shap.Explainer honours `rescale_to_logits` for
# pipeline models in the installed shap version.
explainer = shap.Explainer(pred, rescale_to_logits=rescale_logits)
shap_values = explainer([text])

# display=False makes shap return the plot as an HTML string instead of
# rendering it, so it can be embedded through the components API.
shap_plot = shap.plots.text(shap_values, display=False)

st.title('Interactive Shap Force Plot')
components.html(
    shap_plot,
    height=output_height,
    width=output_width,
    scrolling=True,
)