|
import shap |
|
import numpy as np |
|
import torch |
|
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer |
|
import streamlit as st |
|
|
|
model_name = "mavinsao/mi-roberta-base-finetuned-mental-health"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_classifier(name: str, device_type: str):
    """Load the fine-tuned model, its tokenizer and a pipeline wrapping both.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps a single loaded copy per (name, device_type)
    instead of reloading the model from disk/hub on every rerun.

    Args:
        name: Hugging Face model identifier to load.
        device_type: torch device type string (e.g. "cuda" or "cpu"); passed
            as a plain string so Streamlit can hash it for the cache key.

    Returns:
        A ``(model, tokenizer, pipeline)`` tuple.
    """
    target_device = torch.device(device_type)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(name)
    loaded_model.to(target_device)
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    # top_k=None makes the pipeline return a score for every label. SHAP's
    # pipeline wrapper fills one output column per label from these scores; with
    # the default (top label only) every other label's output would stay at 0
    # and its explanation would be meaningless.
    # device= keeps pipeline inputs on the same device as the model.
    text_pipeline = pipeline(
        "text-classification",
        model=loaded_model,
        tokenizer=loaded_tokenizer,
        device=target_device,
        top_k=None,
    )
    return loaded_model, loaded_tokenizer, text_pipeline


model, tokenizer, classifier = _load_classifier(model_name, device.type)
|
|
|
|
|
# Page header, followed by the free-form input whose prediction gets explained.
st.title(body="SHAP Explanation for Mental Illness Prediction")
text = st.text_area(label="Enter a sentence to explain:")
|
|
|
if st.button("Explain"):
    # SHAP values are only computed when the user explicitly clicks the button.
    if not text.strip():
        # An empty/whitespace-only input has nothing to attribute; skip the
        # expensive explainer call and tell the user instead.
        st.warning("Please enter some text to explain.")
    else:
        # The tokenizer is used as the masker, so the input is perturbed at the
        # tokenizer's token granularity.
        explainer = shap.Explainer(classifier, masker=tokenizer)

        with st.spinner("Computing SHAP values..."):
            # The explainer is called on a batch, hence the single-element list.
            shap_values = explainer([text])

        # display=False returns the rendered plot as an HTML string rather than
        # displaying it, so it can be saved and embedded below.
        shap_html = shap.plots.text(shap_values, display=False)

        # Keep a standalone copy of the plot on disk. Use an explicit utf-8
        # encoding so non-ASCII input does not fail on platforms whose default
        # encoding is not utf-8.
        with open("shap_plot.html", "w", encoding="utf-8") as f:
            f.write(shap_html)

        st.components.v1.html(shap_html, height=500, scrolling=True)
|
|