mavinsao commited on
Commit
c9b5b01
·
verified ·
1 Parent(s): 74e968f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py CHANGED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shap
2
+ import numpy as np
3
+ import torch
4
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
5
+ import streamlit as st
6
+
7
+ model_name = "mavinsao/mi-roberta-base-finetuned-mental-health"
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model.to(device)
13
+
14
+ # Create a pipeline with the model and tokenizer
15
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
16
+
17
+ # Streamlit app
18
+ st.title("SHAP Explanation for Mental Illness Prediction")
19
+
20
+ # Input text area for user input
21
+ text = st.text_area("Enter a sentence to explain:")
22
+
23
+ if st.button("Explain"):
24
+ # Generate the SHAP explainer
25
+ explainer = shap.Explainer(classifier, masker=tokenizer)
26
+
27
+ # Compute SHAP values
28
+ shap_values = explainer([text])
29
+
30
+ # Save SHAP plot as HTML
31
+ shap_html = shap.plots.text(shap_values, display=False)
32
+
33
+ # Save the plot to an HTML file
34
+ shap_html.save_html("shap_plot.html")
35
+
36
+ # Read the HTML file and display in Streamlit
37
+ with open("shap_plot.html", "r") as f:
38
+ shap_html = f.read()
39
+
40
+ # Display the SHAP plot in Streamlit using components
41
+ st.components.v1.html(shap_html, height=500, scrolling=True)