# NOTE: the next three lines are residue from the hosting web page, not program code.
# Ashed00's picture
# Update app.py
# 59d626e verified
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import shap
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Load model and tokenizer with caching
@st.cache_resource
def load_model():
    """Load the sentiment tokenizer and sequence-classification model.

    Both objects come from the same pretrained checkpoint. The function is
    wrapped in ``st.cache_resource`` so the loaded pair is reused instead of
    being rebuilt every time the script runs.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "nlptown/bert-base-multilingual-uncased-sentiment"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSequenceClassification.from_pretrained(checkpoint),
    )


# Module-level handles used by the prediction function and the SHAP setup.
tokenizer, model = load_model()
# Define prediction function
def predict(texts):
    """Return softmax class probabilities for a batch of inputs.

    Args:
        texts: Iterable whose items are either strings or lists of tokens.
            A token list is joined back into a single string with the
            tokenizer before encoding; anything else is passed through as is.

    Returns:
        NumPy array holding the softmax of the model's logits over the last
        axis (one row per input item).
    """
    batch = [
        tokenizer.convert_tokens_to_string(item) if isinstance(item, list) else item
        for item in texts
    ]
    encoded = tokenizer(
        batch,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
        padding=True,
        return_tensors="pt",
    )
    # Inference only: no gradient tracking is needed for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    return probabilities.numpy()
# Initialize SHAP components
# Class labels in index order. The number of classes is read from the model
# config rather than hard-coded, so the list stays correct if the checkpoint
# loaded by load_model() is ever swapped for one with a different label count.
output_names = [model.config.id2label[i] for i in range(model.config.num_labels)]

# Text masker: hidden spans are replaced by the tokenizer's mask token, with
# adjacent mask tokens collapsed into one.
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)

# Explainer that attributes each output of predict() to the input tokens.
explainer = shap.Explainer(predict, masker, output_names=output_names)
# Streamlit UI
# Page title followed by a short numbered walkthrough rendered as Markdown.
st.title("🎯 BERT Sentiment Analysis with SHAP")
st.markdown(
    "\n"
    "**How it works:**\n"
    "1. Enter text in the box below\n"
    "2. See predicted sentiment (1-5 stars)\n"
    "3. View confidence scores and word-level explanations\n"
)
# Session-state key bound to the main text area, so example buttons can fill it.
_INPUT_KEY = "text_input"


def _use_example(example_text):
    """Button callback: copy an example sentence into the main input box.

    Runs as an ``on_click`` callback, i.e. before the script reruns, so the
    value is in session state by the time the keyed text area is created.
    """
    st.session_state[_INPUT_KEY] = example_text


def _plot_token_attributions(tokens, values, class_name):
    """Build a horizontal bar chart of per-token SHAP values for one class.

    Args:
        tokens: Sequence of token strings for the explained text.
        values: SHAP values aligned with ``tokens`` (same length).
        class_name: Label of the class being explained; used in the title.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    df = pd.DataFrame({"token": tokens, "shap_value": values})
    # Smallest magnitude first, so the most influential token ends up on top.
    df = df.sort_values("shap_value", key=np.abs, ascending=True)

    # Plot by position, not by token string: a string-valued axis is
    # categorical, so repeated tokens would collapse onto a single bar.
    positions = np.arange(len(df))
    # Colours follow the legend shown above the tabs (red = positive impact,
    # blue = negative impact).
    colors = ["tab:red" if v > 0 else "tab:blue" for v in df["shap_value"]]

    fig, ax = plt.subplots(figsize=(8, max(4, len(df) * 0.3)))
    ax.barh(positions, df["shap_value"], color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(df["token"])
    ax.set_xlabel("SHAP value")
    ax.set_title(f"SHAP bar plot for class '{class_name}'")
    return fig


text_input = st.text_area(
    "Input Text",
    key=_INPUT_KEY,
    placeholder="Enter text to analyze...",
    height=100,
)

if st.button("Analyze Sentiment"):
    if text_input.strip():
        with st.spinner("Analyzing..."):
            # Class probabilities for the single input text.
            probabilities = predict([text_input])[0]
            predicted_class = int(np.argmax(probabilities))

            # Headline prediction on the left, per-class confidence on the right.
            st.subheader("📊 Results")
            cols = st.columns(2)
            cols[0].metric("Predicted Sentiment", output_names[predicted_class])
            with cols[1]:
                st.markdown("**Confidence Scores**")
                for label, score in zip(output_names, probabilities):
                    confidence = float(score)
                    st.progress(confidence, text=f"{label}: {confidence:.1%}")

            st.subheader("🔍 Explanation")
            st.markdown(
                "\n"
                "**Feature importance (word-level impacts)**\n"
                "🔴 Higher positive values → Increases sentiment\n"
                "🔵 Lower negative values → Decreases sentiment\n"
            )

            # SHAP values for the one input; indexed below as
            # [example, token, class].
            shap_values = explainer([text_input])
            tokens = shap_values.data[0]

            # One tab per sentiment class, each with its own attribution chart.
            tabs = st.tabs(output_names)
            for i, tab in enumerate(tabs):
                with tab:
                    fig = _plot_token_attributions(
                        tokens, shap_values.values[0, :, i], output_names[i]
                    )
                    st.pyplot(fig)
                    # Release the figure; the script reruns on every interaction.
                    plt.close(fig)
    else:
        st.warning("Please enter some text to analyze")

st.markdown("---")
st.markdown("Example texts to try:")
example_texts = [
    "This product exceeded all my expectations!",
    "Terrible customer service experience.",
    "The movie was okay, nothing special.",
    "You are kinda cool",
]
examples = st.columns(len(example_texts))
for col, text in zip(examples, example_texts):
    with col:
        # Clicking an example fills the main input box (via the callback), so
        # the "Analyze Sentiment" button then works on the chosen example.
        st.button(
            text,
            use_container_width=True,
            on_click=_use_example,
            args=(text,),
        )