Ashed00 commited on
Commit
59d626e
Β·
verified Β·
1 Parent(s): aff6f72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -13
app.py CHANGED
@@ -4,6 +4,7 @@ import shap
4
  import torch
5
  import numpy as np
6
  import matplotlib.pyplot as plt
 
7
 
8
  # Load model and tokenizer with caching
9
  @st.cache_resource
@@ -18,8 +19,8 @@ tokenizer, model = load_model()
18
  def predict(texts):
19
  processed_texts = []
20
  for text in texts:
21
- processed_texts.append(text if not isinstance(text, list)
22
- else tokenizer.convert_tokens_to_string(text))
23
 
24
  inputs = tokenizer(
25
  processed_texts,
@@ -44,9 +45,9 @@ explainer = shap.Explainer(predict, masker, output_names=output_names)
44
  st.title("🎯 BERT Sentiment Analysis with SHAP")
45
  st.markdown("""
46
  **How it works:**
47
- 1. Enter text in the box below
48
- 2. See predicted sentiment (1-5 stars)
49
- 3. View confidence scores and word-level explanations
50
  """)
51
 
52
  text_input = st.text_area("Input Text", placeholder="Enter text to analyze...", height=100)
@@ -62,7 +63,6 @@ if st.button("Analyze Sentiment"):
62
  st.subheader("πŸ“Š Results")
63
  cols = st.columns(2)
64
  cols[0].metric("Predicted Sentiment", output_names[predicted_class])
65
-
66
  with cols[1]:
67
  st.markdown("**Confidence Scores**")
68
  for label, score in zip(output_names, probabilities):
@@ -76,19 +76,30 @@ if st.button("Analyze Sentiment"):
76
  πŸ”΅ Lower negative values β†’ Decreases sentiment
77
  """)
78
 
 
79
  shap_values = explainer([text_input])
80
-
81
  # Create tabs for each sentiment class
82
  tabs = st.tabs(output_names)
83
  for i, tab in enumerate(tabs):
84
  with tab:
85
- # Create a bar plot of SHAP values
86
- fig, ax = plt.subplots(figsize=(8, 4))
87
- shap.plots.bar(shap_values[:, :, i], show=False)
88
-
89
- # Display the plot in Streamlit
 
 
 
 
 
 
 
 
 
 
90
  st.pyplot(fig)
91
- plt.close(fig) # Free memory after rendering
92
 
93
  else:
94
  st.warning("Please enter some text to analyze")
 
4
  import torch
5
  import numpy as np
6
  import matplotlib.pyplot as plt
7
+ import pandas as pd
8
 
9
  # Load model and tokenizer with caching
10
  @st.cache_resource
 
19
  def predict(texts):
20
  processed_texts = []
21
  for text in texts:
22
+ processed_texts.append(text if not isinstance(text, list)
23
+ else tokenizer.convert_tokens_to_string(text))
24
 
25
  inputs = tokenizer(
26
  processed_texts,
 
45
  st.title("🎯 BERT Sentiment Analysis with SHAP")
46
  st.markdown("""
47
  **How it works:**
48
+ 1. Enter text in the box below
49
+ 2. See predicted sentiment (1-5 stars)
50
+ 3. View confidence scores and word-level explanations
51
  """)
52
 
53
  text_input = st.text_area("Input Text", placeholder="Enter text to analyze...", height=100)
 
63
  st.subheader("πŸ“Š Results")
64
  cols = st.columns(2)
65
  cols[0].metric("Predicted Sentiment", output_names[predicted_class])
 
66
  with cols[1]:
67
  st.markdown("**Confidence Scores**")
68
  for label, score in zip(output_names, probabilities):
 
76
  πŸ”΅ Lower negative values β†’ Decreases sentiment
77
  """)
78
 
79
+ # Get SHAP values for the input text
80
  shap_values = explainer([text_input])
81
+
82
  # Create tabs for each sentiment class
83
  tabs = st.tabs(output_names)
84
  for i, tab in enumerate(tabs):
85
  with tab:
86
+ # Extract the values and corresponding tokens for our single example.
87
+ # shap_values is of shape (1, num_tokens, num_classes)
88
+ values = shap_values.values[0, :, i] # SHAP values for class i
89
+ tokens = shap_values.data[0] # Tokenized words
90
+
91
+ # Create a DataFrame to sort and plot the tokens by importance
92
+ df = pd.DataFrame({"token": tokens, "shap_value": values})
93
+ # Sort tokens by the absolute SHAP value (smallest at the bottom for horizontal bar plot)
94
+ df = df.sort_values("shap_value", key=lambda x: np.abs(x), ascending=True)
95
+
96
+ # Create a horizontal bar plot
97
+ fig, ax = plt.subplots(figsize=(8, max(4, len(tokens) * 0.3)))
98
+ ax.barh(df["token"], df["shap_value"], color='skyblue')
99
+ ax.set_xlabel("SHAP value")
100
+ ax.set_title(f"SHAP bar plot for class '{output_names[i]}'")
101
  st.pyplot(fig)
102
+ plt.close(fig)
103
 
104
  else:
105
  st.warning("Please enter some text to analyze")