File size: 3,792 Bytes
75893d7
 
 
 
 
 
 
 
 
 
 
 
 
7988322
 
 
 
 
 
 
 
75893d7
7988322
75893d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
039a1e6
 
 
 
 
 
3679668
 
 
039a1e6
75893d7
039a1e6
 
 
75893d7
039a1e6
 
75893d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
039a1e6
75893d7
 
 
 
 
 
 
039a1e6
75893d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7299bce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import shap
import torch
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

# Load the sentiment-classification checkpoint together with its tokenizer.
# Both come from the same Hugging Face model id, so the vocabulary used to
# encode text always matches the one the classifier was trained with.
MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Define prediction function
def predict(texts):
    """Return softmax class probabilities for a batch of texts.

    This is called directly by ``analyze_text`` and also by the SHAP
    explainer as its black-box model. Each item may be a plain string or a
    list of tokenizer tokens; token lists are joined back into a string with
    ``tokenizer.convert_tokens_to_string`` before encoding.

    Args:
        texts: Iterable of strings and/or token lists.

    Returns:
        ``np.ndarray`` with one row per input and one column per model label,
        holding softmax probabilities. An empty batch yields an empty
        ``(0, num_labels)`` array instead of an error.
    """
    processed_texts = [
        tokenizer.convert_tokens_to_string(text) if isinstance(text, list) else text
        for text in texts
    ]

    if not processed_texts:
        # The tokenizer/model cannot process an empty batch; return an empty
        # result that still has the expected number of columns.
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    inputs = tokenizer(
        processed_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=True,
    )
    # Put the encoded tensors on the same device as the model (CPU unless the
    # model has been moved elsewhere).
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # .cpu() is a no-op for CPU tensors and required before .numpy() on GPU.
    return probabilities.cpu().numpy()

# Initialize SHAP components.
# Output names are listed in label-index order (0, 1, ...), so column i of
# the explainer's output corresponds to model.config.id2label[i].
output_names_list = [
    model.config.id2label[label_idx]
    for label_idx in range(len(model.config.id2label))
]
# Text masker that hides tokens using the tokenizer's own mask token
# (see the shap.maskers.Text docs for what collapse_mask_token controls).
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)
# The explainer treats `predict` as a black box mapping texts -> probabilities.
explainer = shap.Explainer(
    model=predict,
    masker=masker,
    output_names=output_names_list,
)

def analyze_text(text):
    """Classify *text* and explain the prediction with SHAP.

    Args:
        text: Raw input string from the Gradio textbox.

    Returns:
        A tuple ``(predicted_label, confidence_scores, *html_plots)``:

        - ``predicted_label``: the highest-probability label.
        - ``confidence_scores``: a ``{label: probability}`` dict for every class.
        - ``html_plots``: one SHAP text-plot HTML string per class, in
          label-index order.

        The tuple always has ``2 + num_labels`` entries so it lines up with
        the UI outputs. For empty or whitespace-only input, the label and
        scores are ``None`` and every plot is an empty string.
    """
    num_labels = len(model.config.id2label)

    # Nothing to classify or attribute: clear the outputs rather than running
    # the SHAP explainer on an input without any tokens.
    if not text or not text.strip():
        return (None, None, *([""] * num_labels))

    # Get predictions
    probabilities = predict([text])[0]
    # np.argmax returns a NumPy integer; convert it to a plain int for the
    # id2label lookup.
    predicted_class = int(np.argmax(probabilities))
    predicted_label = model.config.id2label[predicted_class]

    # Generate SHAP explanations; values are indexed [sample, token, class].
    shap_values = explainer([text])

    # One HTML text plot per output class.
    html_plots = [
        shap.plots.text(shap_values[0, :, class_idx], display=False)
        for class_idx in range(shap_values.shape[-1])
    ]

    # Format confidence scores
    confidence_scores = {
        model.config.id2label[i]: float(prob) for i, prob in enumerate(probabilities)
    }

    return (predicted_label, confidence_scores, *html_plots)

# Create Gradio interface with HTML components.
# The click handler's outputs are [label, scores, one HTML plot per label];
# this must match the tuple returned by analyze_text.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🔍 BERT Sentiment Analysis with SHAP Explanations")
    
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze...")
        
    with gr.Row():
        predict_btn = gr.Button("Analyze Sentiment")
        
    with gr.Row():
        label_output = gr.Label(label="Predicted Sentiment")
        prob_output = gr.Label(label="Confidence Scores")
        
    with gr.Row():
        gr.Markdown("""
        ### SHAP Explanations
        Below you can see how each word contributes to different sentiment scores (1-5 stars).  
        Red text increases the score, blue decreases it.
        """)
    
    # Individual explanation rows: one HTML component per model label. The
    # count comes from the model config instead of a hard-coded 5, so it
    # always matches the number of plots analyze_text returns.
    plot_components = []
    for i in range(len(model.config.id2label)):
        with gr.Row():
            plot_components.append(
                gr.HTML(
                    label=f"Explanation for {model.config.id2label[i]}",
                    elem_classes=f"shap-plot-{i+1}"
                )
            )

    predict_btn.click(
        fn=analyze_text,
        inputs=input_text,
        outputs=[label_output, prob_output] + plot_components
    )
    
    examples = gr.Examples(
        examples=[
            ["This product exceeded all my expectations!"],
            ["Terrible customer service experience."],
            ["The movie was okay, nothing special."],
            ["You are kinda cool"],
        ],
        inputs=input_text
    )

if __name__ == "__main__":
    # Start the Gradio app with debug mode enabled when run as a script.
    demo.launch(debug=True)