"""Streamlit app that explains a mental-illness text classifier's predictions with SHAP."""

import shap
import numpy as np
import torch
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import streamlit as st
import streamlit.components.v1 as components

model_name = "mavinsao/mi-roberta-base-finetuned-mental-health"


@st.cache_resource
def _load_classifier(name):
    """Load model + tokenizer once per server process and wrap them in a pipeline.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be re-loaded on every button click.

    Returns:
        A ``(model, tokenizer, device, classifier)`` tuple.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForSequenceClassification.from_pretrained(name)
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(name)
    # top_k=None makes the pipeline return a score for *every* label. SHAP's
    # pipeline wrapper only fills in the labels the pipeline returns, so the
    # default (top label only) would leave every other class at a score of 0.
    classifier = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        device=device,
        top_k=None,
    )
    return model, tokenizer, device, classifier


# Create (or fetch the cached) pipeline with the model and tokenizer
model, tokenizer, device, classifier = _load_classifier(model_name)

# Streamlit app
st.title("SHAP Explanation for Mental Illness Prediction")

# Input text area for user input
text = st.text_area("Enter a sentence to explain:")

if st.button("Explain"):
    if not text.strip():
        # Nothing to explain; avoid running SHAP on an empty string.
        st.warning("Please enter some text to explain.")
    else:
        with st.spinner("Computing SHAP values..."):
            # Generate the SHAP explainer; the tokenizer acts as a text masker.
            explainer = shap.Explainer(classifier, masker=tokenizer)

            # Compute SHAP values
            shap_values = explainer([text])

        # With display=False, shap.plots.text returns the plot as an HTML string.
        shap_html = shap.plots.text(shap_values, display=False)

        # Save the plot to an HTML file; utf-8 so non-ASCII input text cannot
        # fail on platforms whose default encoding is not UTF-8.
        with open("shap_plot.html", "w", encoding="utf-8") as f:
            f.write(shap_html)

        # Display the SHAP plot in Streamlit using components
        components.html(shap_html, height=500, scrolling=True)