File size: 4,710 Bytes
a9dac34
 
 
 
 
aff6f72
59d626e
a9dac34
 
 
 
 
 
 
 
 
 
 
 
 
 
59d626e
 
a9dac34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59d626e
 
 
a9dac34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2f811f
31beafc
a9dac34
 
 
 
aff6f72
 
 
a9dac34
d2f811f
59d626e
a9dac34
59d626e
a9dac34
 
 
 
59d626e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aff6f72
59d626e
d2f811f
a9dac34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2f811f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import shap
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Load model and tokenizer with caching
@st.cache_resource
def load_model():
    """Load the pretrained sentiment tokenizer and classifier.

    Both objects are loaded from the same checkpoint; the function is wrapped
    in ``st.cache_resource``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "nlptown/bert-base-multilingual-uncased-sentiment"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSequenceClassification.from_pretrained(checkpoint),
    )

# Module-level tokenizer and model, unpacked from load_model(); predict() below reads both.
tokenizer, model = load_model()

# Define prediction function
def predict(texts):
    """Return class probabilities for a batch of texts.

    Args:
        texts: Iterable of inputs. An item that is a ``list`` is treated as a
            list of tokens and joined back into a string with the tokenizer;
            any other item is passed to the tokenizer unchanged.

    Returns:
        NumPy array holding the softmax of the model logits over the last axis.
    """
    batch = [
        tokenizer.convert_tokens_to_string(item) if isinstance(item, list) else item
        for item in texts
    ]

    encoded = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=True,
    )

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    return probabilities.numpy()

# Initialize SHAP components
# Class labels ordered by class index, read from the model config instead of
# assuming a fixed number of classes.
output_names = [label for _, label in sorted(model.config.id2label.items())]
# Text masker built on the model's tokenizer, using its mask token.
masker = shap.maskers.Text(tokenizer=tokenizer, mask_token=tokenizer.mask_token, collapse_mask_token=True)
# Explainer that calls predict() and reports outputs under the class labels above.
explainer = shap.Explainer(predict, masker, output_names=output_names)

# Streamlit UI
# Session-state key bound to the main input box, so that the example buttons
# can fill that box instead of rendering a second, disconnected text area.
INPUT_KEY = "text_input"


def _use_example(example_text):
    """Button callback: copy an example sentence into the main input box."""
    st.session_state[INPUT_KEY] = example_text


def _plot_token_impacts(tokens, values, class_name):
    """Build a horizontal bar chart of per-token SHAP values for one class.

    Bars are sorted by absolute SHAP value (smallest at the bottom) and drawn
    at numeric positions with the tokens as tick labels. Plotting the token
    strings directly as categories would draw every occurrence of a repeated
    token on the same row.

    Args:
        tokens: Token strings of the explained text.
        values: SHAP values aligned one-to-one with ``tokens``.
        class_name: Class label shown in the plot title.

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    df = pd.DataFrame({"token": tokens, "shap_value": values})
    df = df.sort_values("shap_value", key=np.abs, ascending=True)
    positions = np.arange(len(df))

    fig, ax = plt.subplots(figsize=(8, max(4, len(df) * 0.3)))
    ax.barh(positions, df["shap_value"].to_numpy(), color='skyblue')
    ax.set_yticks(positions)
    ax.set_yticklabels(df["token"].tolist())
    ax.set_xlabel("SHAP value")
    ax.set_title(f"SHAP bar plot for class '{class_name}'")
    return fig


st.title("🎯 BERT Sentiment Analysis with SHAP")
st.markdown("""
**How it works:**
1. Enter text in the box below  
2. See predicted sentiment (1-5 stars)  
3. View confidence scores and word-level explanations  
""")

text_input = st.text_area(
    "Input Text",
    placeholder="Enter text to analyze...",
    height=100,
    key=INPUT_KEY,
)

if st.button("Analyze Sentiment"):
    if text_input.strip():
        with st.spinner("Analyzing..."):
            # Get predictions for the single input text
            probabilities = predict([text_input])[0]
            predicted_class = int(np.argmax(probabilities))

            # Display results
            st.subheader("📊 Results")
            cols = st.columns(2)
            cols[0].metric("Predicted Sentiment", output_names[predicted_class])
            with cols[1]:
                st.markdown("**Confidence Scores**")
                for label, score in zip(output_names, probabilities):
                    st.progress(float(score), text=f"{label}: {score:.1%}")

            # Generate SHAP explanations
            st.subheader("🔍 Explanation")
            # Two trailing spaces before each newline force a Markdown line break.
            st.markdown(
                "**Feature importance (word-level impacts)**  \n"
                "🔴 Higher positive values → Increases sentiment  \n"
                "🔵 Lower negative values → Decreases sentiment"
            )

            # Get SHAP values for the input text
            shap_values = explainer([text_input])
            tokens = shap_values.data[0]
            # Assumes the per-example values are (num_tokens, num_classes),
            # as the column indexing below requires.
            token_values = shap_values.values[0]

            # One tab per sentiment class
            for i, tab in enumerate(st.tabs(output_names)):
                with tab:
                    fig = _plot_token_impacts(tokens, token_values[:, i], output_names[i])
                    st.pyplot(fig)
                    # Close the figure so reruns do not accumulate open figures.
                    plt.close(fig)
    else:
        st.warning("Please enter some text to analyze")

st.markdown("---")
st.markdown("Example texts to try:")
examples = st.columns(4)
example_texts = [
    "This product exceeded all my expectations!",
    "Terrible customer service experience.",
    "The movie was okay, nothing special.",
    "You are kinda cool",
]

for col, text in zip(examples, example_texts):
    with col:
        # The callback runs before the next rerun, so the main input box is
        # already filled with the example when the script executes again.
        st.button(text, use_container_width=True, on_click=_use_example, args=(text,))