Ashed00 commited on
Commit
75893d7
·
verified ·
1 Parent(s): cf19128

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import shap
3
+ import torch
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
+ # Load model and tokenizer
9
+ tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
10
+ model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
11
+
12
+ # Define prediction function
13
+ def predict(texts):
14
+ processed_texts = []
15
+ for text in texts:
16
+ if isinstance(text, list):
17
+ processed_text = tokenizer.convert_tokens_to_string(text)
18
+ else:
19
+ processed_text = text
20
+ processed_texts.append(processed_text)
21
+
22
+ inputs = tokenizer(
23
+ processed_texts,
24
+ return_tensors="pt",
25
+ padding=True,
26
+ truncation=True,
27
+ max_length=512,
28
+ add_special_tokens=True
29
+ )
30
+
31
+ with torch.no_grad():
32
+ outputs = model(**inputs)
33
+
34
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
35
+ return probabilities.numpy()
36
+
37
+ # Initialize SHAP components
38
+ output_names_list = [model.config.id2label[i] for i in range(len(model.config.id2label))]
39
+ masker = shap.maskers.Text(tokenizer=tokenizer, mask_token=tokenizer.mask_token, collapse_mask_token=True)
40
+ explainer = shap.Explainer(model=predict, masker=masker, output_names=output_names_list)
41
+
42
+ def analyze_text(text):
43
+ # Get predictions
44
+ probabilities = predict([text])[0]
45
+ predicted_class = np.argmax(probabilities)
46
+ predicted_label = model.config.id2label[predicted_class]
47
+
48
+ # Generate SHAP explanations
49
+ shap_values = explainer([text])
50
+
51
+ # Create HTML visualizations for all classes
52
+ html_plots = []
53
+ for i in range(shap_values.shape[-1]):
54
+ # Create SHAP text plot and convert to HTML
55
+ plot_html = shap.plots.text(shap_values[0, :, i], display=False)
56
+ html_plots.append(plot_html)
57
+
58
+ # Format confidence scores
59
+ confidence_scores = {model.config.id2label[i]: float(probabilities[i])
60
+ for i in range(len(probabilities))}
61
+
62
+ return (predicted_label,
63
+ confidence_scores,
64
+ *html_plots)
65
+
66
+ # Create Gradio interface with HTML components
67
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
68
+ gr.Markdown("## 🔍 BERT Sentiment Analysis with SHAP Explanations")
69
+
70
+ with gr.Row():
71
+ input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze...")
72
+
73
+ with gr.Row():
74
+ predict_btn = gr.Button("Analyze Sentiment")
75
+
76
+ with gr.Row():
77
+ label_output = gr.Label(label="Predicted Sentiment")
78
+ prob_output = gr.Label(label="Confidence Scores")
79
+
80
+ with gr.Row():
81
+ gr.Markdown("""
82
+ ### SHAP Explanations
83
+ Below you can see how each word contributes to different sentiment scores (1-5 stars).
84
+ Red text increases the score, blue decreases it.
85
+ """)
86
+
87
+ # Individual Explanation Rows
88
+ plot_components = []
89
+ for i in range(5):
90
+ with gr.Row():
91
+ plot_components.append(
92
+ gr.HTML(
93
+ label=f"Explanation for {model.config.id2label[i]}",
94
+ elem_classes=f"shap-plot-{i+1}"
95
+ )
96
+ )
97
+
98
+ predict_btn.click(
99
+ fn=analyze_text,
100
+ inputs=input_text,
101
+ outputs=[label_output, prob_output] + plot_components
102
+ )
103
+
104
+ examples = gr.Examples(
105
+ examples=[
106
+ ["This product exceeded all my expectations!"],
107
+ ["Terrible customer service experience."],
108
+ ["The movie was okay, nothing special."],
109
+ ["You are kinda cool"],
110
+ ],
111
+ inputs=input_text
112
+ )
113
+
114
+ if __name__ == "__main__":
115
+ demo.launch(debug = True)