import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import shap
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Load model and tokenizer with caching so they are only loaded once per session
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
    model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
    return tokenizer, model

tokenizer, model = load_model()
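# A note on the labels (hedged): nlptown/bert-base-multilingual-uncased-sentiment is a 5-class
# star-rating model, so model.config.id2label is expected to look roughly like
#   {0: "1 star", 1: "2 stars", 2: "3 stars", 3: "4 stars", 4: "5 stars"}
# The exact strings come from the model config; they are used as tab and metric names below.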
# Define prediction function
def predict(texts):
    processed_texts = []
    for text in texts:
        processed_texts.append(text if not isinstance(text, list)
                               else tokenizer.convert_tokens_to_string(text))
    inputs = tokenizer(
        processed_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=True
    )
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.nn.functional.softmax(outputs.logits, dim=-1).numpy()
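# Illustrative sketch of what predict() returns (values below are made up, not real model output):
#   predict(["Great product!"])
#   -> array([[0.01, 0.02, 0.05, 0.22, 0.70]])   # one row per input text, one column per class
# Each row sums to 1 because of the softmax, so np.argmax over a row gives the predicted class.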
# Initialize SHAP components
output_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
masker = shap.maskers.Text(tokenizer=tokenizer, mask_token=tokenizer.mask_token, collapse_mask_token=True)
explainer = shap.Explainer(predict, masker, output_names=output_names)
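# Sketch of how the explanation is produced (as I understand shap's defaults for this setup):
# the Text masker hides subsets of tokens by replacing them with tokenizer.mask_token, predict()
# is re-run on the perturbed strings, and shap.Explainer turns the resulting changes in class
# probabilities into one attribution per token per output class. For example (hypothetical call),
#   explainer(["Great product!"]).values[0, :, 4]
# would hold the per-token contributions to the "5 stars" class.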
# Streamlit UI
st.title("🎯 BERT Sentiment Analysis with SHAP")
st.markdown("""
**How it works:**
1. Enter text in the box below
2. See the predicted sentiment (1-5 stars)
3. View confidence scores and word-level explanations
""")
text_input = st.text_area(
    "Input Text",
    value=st.session_state.get("last_input", ""),
    placeholder="Enter text to analyze...",
    height=100,
)
if st.button("Analyze Sentiment"):
    if text_input.strip():
        with st.spinner("Analyzing..."):
            # Get predictions
            probabilities = predict([text_input])[0]
            predicted_class = np.argmax(probabilities)

            # Display results
            st.subheader("📊 Results")
            cols = st.columns(2)
            cols[0].metric("Predicted Sentiment", output_names[predicted_class])
            with cols[1]:
                st.markdown("**Confidence Scores**")
                for label, score in zip(output_names, probabilities):
                    st.progress(float(score), text=f"{label}: {score:.1%}")
            # Generate SHAP explanations
            st.subheader("🔍 Explanation")
            st.markdown("""
**Feature importance (word-level impacts)**

🔴 Positive values push the prediction toward the selected class

🔵 Negative values push the prediction away from it
""")
            # Get SHAP values for the input text
            shap_values = explainer([text_input])
            # Create tabs for each sentiment class
            tabs = st.tabs(output_names)
            for i, tab in enumerate(tabs):
                with tab:
                    # Extract the values and corresponding tokens for our single example.
                    # shap_values.values has shape (1, num_tokens, num_classes)
                    values = shap_values.values[0, :, i]  # SHAP values for class i
                    tokens = shap_values.data[0]  # Tokenized words

                    # Create a DataFrame to sort and plot the tokens by importance
                    df = pd.DataFrame({"token": tokens, "shap_value": values})
                    # Sort by absolute SHAP value so the least important tokens sit at the bottom of the bar plot
                    df = df.sort_values("shap_value", key=lambda x: np.abs(x), ascending=True)

                    # Horizontal bar plot, coloured by sign to match the red/blue legend above
                    fig, ax = plt.subplots(figsize=(8, max(4, len(tokens) * 0.3)))
                    colors = ["crimson" if v > 0 else "steelblue" for v in df["shap_value"]]
                    ax.barh(df["token"], df["shap_value"], color=colors)
                    ax.set_xlabel("SHAP value")
                    ax.set_title(f"SHAP bar plot for class '{output_names[i]}'")
                    st.pyplot(fig)
                    plt.close(fig)
    else:
        st.warning("Please enter some text to analyze")
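# Alternative worth noting (hedged): shap also ships a built-in text plot,
# shap.plots.text(shap_values[0]), which highlights tokens by their contribution. It targets
# notebook/HTML output, so the manual matplotlib bar charts above are used instead because they
# embed directly via st.pyplot without any HTML bridging.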
st.markdown("---") | |
st.markdown("Example texts to try:") | |
examples = st.columns(4) | |
example_texts = [ | |
"This product exceeded all my expectations!", | |
"Terrible customer service experience.", | |
"The movie was okay, nothing special.", | |
"You are kinda cool" | |
] | |
for col, text in zip(examples, example_texts): | |
with col: | |
if st.button(text, use_container_width=True): | |
st.session_state.last_input = text | |
if 'last_input' in st.session_state: | |
text_input = st.text_area("", value=st.session_state.last_input, height=100) | |
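# To run the app locally (assuming this file is saved as app.py and that streamlit, transformers,
# torch, shap, matplotlib and pandas are installed):
#   streamlit run app.py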