File size: 4,710 Bytes
a9dac34 aff6f72 59d626e a9dac34 59d626e a9dac34 59d626e a9dac34 d2f811f 31beafc a9dac34 aff6f72 a9dac34 d2f811f 59d626e a9dac34 59d626e a9dac34 59d626e aff6f72 59d626e d2f811f a9dac34 d2f811f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import shap
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Load model and tokenizer once and reuse them across reruns.
# Streamlit re-executes this whole script on every widget interaction; the
# st.cache_resource decorator keeps the loaded objects alive between reruns
# instead of reloading the checkpoint on each click.
MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"


@st.cache_resource
def load_model(model_name=MODEL_NAME):
    """Load the sentiment tokenizer and sequence-classification model.

    Args:
        model_name: Hugging Face model id or local path. Defaults to the
            multilingual 1-5 star sentiment checkpoint this app is built on.

    Returns:
        A ``(tokenizer, model)`` tuple, cached by Streamlit per ``model_name``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, model


tokenizer, model = load_model()
# Define prediction function
def predict(texts):
    """Return softmax class probabilities for a batch of texts.

    Called directly by the UI and used as the model callback of the SHAP
    explainer, which calls it with masked copies of the input text.

    Args:
        texts: Iterable of inputs. A list element is treated as a token list
            and joined back into text with
            ``tokenizer.convert_tokens_to_string``; any other element is
            converted with ``str``.

    Returns:
        ``numpy.ndarray`` of shape ``(len(texts), model.config.num_labels)``
        whose rows are softmax probabilities (each row sums to 1).
    """
    processed_texts = [
        tokenizer.convert_tokens_to_string(text) if isinstance(text, list) else str(text)
        for text in texts
    ]
    if not processed_texts:
        # Nothing to encode; keep the (batch, num_labels) shape contract.
        return np.zeros((0, model.config.num_labels), dtype=np.float32)
    # padding=True lets texts of different lengths (e.g. SHAP's masked
    # variants) share one batch; truncation=True caps each input at the
    # tokenizer's maximum length instead of overflowing the model.
    inputs = tokenizer(
        processed_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model(**inputs)
    # .cpu() keeps the numpy conversion valid if the model is moved to a GPU.
    return torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()
# Initialize SHAP components
# Class names are read from the model config for every label the model has,
# rather than assuming a fixed count.
output_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
# The text masker hides tokens by substituting the tokenizer's mask token;
# collapse_mask_token=True lets a run of masked tokens share one mask token.
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)
explainer = shap.Explainer(predict, masker, output_names=output_names)
# Streamlit UI
st.title("🎯 BERT Sentiment Analysis with SHAP")
st.markdown(
    """
**How it works:**
1. Enter text in the box below
2. See predicted sentiment (1-5 stars)
3. View confidence scores and word-level explanations
"""
)

example_texts = [
    "This product exceeded all my expectations!",
    "Terrible customer service experience.",
    "The movie was okay, nothing special.",
    "You are kinda cool",
]


def _use_example(example):
    """Button callback: put *example* into the main input box.

    Callbacks run before the script re-executes, so the text area keyed
    ``"text_input"`` has not been created yet in that run and its
    session-state value may still be assigned.
    """
    st.session_state.text_input = example


def _render_results(probabilities):
    """Show the predicted label and one confidence bar per class."""
    predicted_class = int(np.argmax(probabilities))
    st.subheader("📊 Results")
    label_col, scores_col = st.columns(2)
    label_col.metric("Predicted Sentiment", output_names[predicted_class])
    with scores_col:
        st.markdown("**Confidence Scores**")
        for label, score in zip(output_names, probabilities):
            st.progress(float(score), text=f"{label}: {score:.1%}")


def _shap_bar_figure(tokens, values, class_name):
    """Return a horizontal bar chart of per-token SHAP values for one class.

    Bars are ordered by absolute SHAP value (largest at the top). Positive
    values are drawn red and the rest blue, matching the on-page legend.
    """
    order = np.argsort(np.abs(values), kind="stable")
    positions = np.arange(len(order))
    fig, ax = plt.subplots(figsize=(8, max(4, len(tokens) * 0.3)))
    ax.barh(
        positions,
        values[order],
        color=["tab:red" if values[j] > 0 else "tab:blue" for j in order],
    )
    # Plot against numeric positions and label them explicitly: using the
    # token strings as y values would merge repeated tokens into one bar.
    ax.set_yticks(positions)
    ax.set_yticklabels([str(tokens[j]) for j in order])
    ax.axvline(0, color="black", linewidth=0.8)
    ax.set_xlabel("SHAP value")
    ax.set_title(f"SHAP bar plot for class '{class_name}'")
    fig.tight_layout()
    return fig


def _render_explanation(text):
    """Explain *text* with SHAP and show one bar chart per class, in tabs."""
    st.subheader("🔍 Explanation")
    st.markdown(
        """
**Feature importance (word-level impacts)**

- 🔴 Positive values → increase the probability of the selected class
- 🔵 Negative values → decrease the probability of the selected class
"""
    )
    shap_values = explainer([text])
    # A single text was explained: take its tokens and its
    # (num_tokens, num_classes) matrix of SHAP values.
    tokens = shap_values.data[0]
    token_values = np.asarray(shap_values.values[0], dtype=float)
    for i, tab in enumerate(st.tabs(output_names)):
        with tab:
            fig = _shap_bar_figure(tokens, token_values[:, i], output_names[i])
            st.pyplot(fig)
            # pyplot keeps every figure alive until closed; close it so
            # repeated reruns do not accumulate open figures.
            plt.close(fig)


text_input = st.text_area(
    "Input Text",
    placeholder="Enter text to analyze...",
    height=100,
    key="text_input",
)

if st.button("Analyze Sentiment"):
    if text_input.strip():
        with st.spinner("Analyzing..."):
            _render_results(predict([text_input])[0])
            _render_explanation(text_input)
    else:
        st.warning("Please enter some text to analyze")

st.markdown("Example texts to try:")
for col, text in zip(st.columns(len(example_texts)), example_texts):
    with col:
        # Clicking an example fills the main input box via the callback above.
        st.button(text, use_container_width=True, on_click=_use_example, args=(text,))