# Author: mavinsao (commit 84bcaaf, "Update app.py", verified)
import numpy as np
import shap
import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
model_name = "mavinsao/mi-roberta-base-finetuned-mental-health"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_model_and_classifier(name, _device):
    """Load the fine-tuned checkpoint *name* and wrap it in a text-classification pipeline.

    Streamlit re-runs this whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``. That way the weights are
    loaded once per server process instead of on every rerun. The leading
    underscore on ``_device`` tells Streamlit not to hash that argument.

    Returns:
        A ``(model, tokenizer, classifier)`` tuple.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    loaded_model.to(_device)
    # Create a pipeline with the model and tokenizer. The device is passed
    # explicitly so the pipeline places its input tensors on the same device
    # as the model weights. Without it, depending on the transformers version,
    # inputs may stay on CPU while the model sits on CUDA.
    text_classifier = pipeline(
        "text-classification",
        model=loaded_model,
        tokenizer=loaded_tokenizer,
        device=_device,
    )
    return loaded_model, loaded_tokenizer, text_classifier


model, tokenizer, classifier = _load_model_and_classifier(model_name, device)
# Streamlit app
st.title("SHAP Explanation for Mental Illness Prediction")

# Input text area for user input
text = st.text_area("Enter a sentence to explain:")

if st.button("Explain"):
    if not text.strip():
        # Nothing to attribute on empty/whitespace-only input; prompt the user
        # instead of handing SHAP an empty string.
        st.warning("Please enter some text to explain.")
    else:
        # Generate the SHAP explainer; the tokenizer is used as the text masker,
        # so features are the tokenizer's tokens.
        explainer = shap.Explainer(classifier, masker=tokenizer)
        # Compute SHAP values for the single input sentence.
        shap_values = explainer([text])
        # With display=False, shap.plots.text returns the plot as an HTML string
        # instead of rendering it.
        shap_html = shap.plots.text(shap_values, display=False)
        # Save the plot to an HTML file. Write UTF-8 explicitly: the HTML embeds
        # the user's tokens, which may not be encodable in the locale default.
        with open("shap_plot.html", "w", encoding="utf-8") as f:
            f.write(shap_html)
        # Display the SHAP plot in Streamlit using components.
        components.html(shap_html, height=500, scrolling=True)