Ashed00 commited on
Commit
d2f811f
·
verified ·
1 Parent(s): 1591b0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -2,8 +2,8 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import shap
4
  import torch
5
- import matplotlib.pyplot as plt
6
  import numpy as np
 
7
 
8
  # Load model and tokenizer with caching
9
  @st.cache_resource
@@ -65,7 +65,7 @@ if st.button("Analyze Sentiment"):
65
 
66
  with cols[1]:
67
  st.markdown("**Confidence Scores**")
68
- for i, (label, score) in enumerate(zip(output_names, probabilities)):
69
  st.progress(float(score), text=f"{label}: {score:.1%}")
70
 
71
  # Generate SHAP explanations
@@ -75,18 +75,27 @@ if st.button("Analyze Sentiment"):
75
  Red → Increases score | Blue → Decreases score
76
  Intensity shows magnitude of impact
77
  """)
78
-
79
  shap_values = explainer([text_input])
80
-
81
  # Create tabs for each sentiment class
82
  tabs = st.tabs(output_names)
83
  for i, tab in enumerate(tabs):
84
  with tab:
85
- # Create a new matplotlib figure
86
- plt.figure()
87
- shap.plots.text(shap_values[:, :, i],) # Generate SHAP plot
88
- st.pyplot(plt.gcf()) # Pass the current figure to st.pyplot
89
- plt.close() # Close the figure to free memory
 
 
 
 
 
 
 
 
 
90
  else:
91
  st.warning("Please enter some text to analyze")
92
 
@@ -106,4 +115,4 @@ for col, text in zip(examples, example_texts):
106
  st.session_state.last_input = text
107
 
108
  if 'last_input' in st.session_state:
109
- text_input = st.text_area("", value=st.session_state.last_input, height=100)
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import shap
4
  import torch
 
5
  import numpy as np
6
+ import os
7
 
8
  # Load model and tokenizer with caching
9
  @st.cache_resource
 
65
 
66
  with cols[1]:
67
  st.markdown("**Confidence Scores**")
68
+ for label, score in zip(output_names, probabilities):
69
  st.progress(float(score), text=f"{label}: {score:.1%}")
70
 
71
  # Generate SHAP explanations
 
75
  Red → Increases score | Blue → Decreases score
76
  Intensity shows magnitude of impact
77
  """)
78
+
79
  shap_values = explainer([text_input])
80
+
81
  # Create tabs for each sentiment class
82
  tabs = st.tabs(output_names)
83
  for i, tab in enumerate(tabs):
84
  with tab:
85
+ # Save SHAP text explanation as an HTML file
86
+ html_path = f"shap_explanation_{i}.html"
87
+ with open(html_path, "w", encoding="utf-8") as file:
88
+ file.write(shap.plots.text(shap_values[:, :, i], display=False))
89
+
90
+ # Read and display the saved HTML file in Streamlit
91
+ with open(html_path, "r", encoding="utf-8") as file:
92
+ shap_html = file.read()
93
+
94
+ st.components.v1.html(shap_html, height=400, scrolling=True)
95
+
96
+ # Clean up temporary HTML files (optional)
97
+ os.remove(html_path)
98
+
99
  else:
100
  st.warning("Please enter some text to analyze")
101
 
 
115
  st.session_state.last_input = text
116
 
117
  if 'last_input' in st.session_state:
118
+ text_input = st.text_area("", value=st.session_state.last_input, height=100)