Ashed00 commited on
Commit
3679668
Β·
verified Β·
1 Parent(s): 0f79fd5

Update app.py

Browse files

Update change HTML to Images.

Files changed (1) hide show
  1. app.py +25 -28
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  import numpy as np
 
7
 
8
  # Load model and tokenizer
9
  tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
@@ -11,16 +12,8 @@ model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-mu
11
 
12
  # Define prediction function
13
  def predict(texts):
14
- processed_texts = []
15
- for text in texts:
16
- if isinstance(text, list):
17
- processed_text = tokenizer.convert_tokens_to_string(text)
18
- else:
19
- processed_text = text
20
- processed_texts.append(processed_text)
21
-
22
  inputs = tokenizer(
23
- processed_texts,
24
  return_tensors="pt",
25
  padding=True,
26
  truncation=True,
@@ -48,25 +41,28 @@ def analyze_text(text):
48
  # Generate SHAP explanations
49
  shap_values = explainer([text])
50
 
51
- # Create HTML visualizations for all classes
52
- html_plots = []
53
- for i in range(shap_values.shape[-1]):
54
- # Generate SHAP plot HTML
55
- plot_html = shap.plots.text(shap_values[0, :, i], display=False)
 
 
56
 
57
- # Create HTML visualizations for all classes
58
- html_plots = []
59
- for i in range(shap_values.shape[-1]):
60
- # Create SHAP text plot and convert to HTML
61
- plot_html = shap.plots.text(shap_values[0, :, i], display=False)
62
- html_plots.append(plot_html)
63
- # Format confidence scores
64
- confidence_scores = {model.config.id2label[i]: float(probabilities[i])
65
- for i in range(len(probabilities))}
 
66
 
67
- return (predicted_label, confidence_scores, *html_plots)
68
 
69
- # Create Gradio interface with HTML components
70
  with gr.Blocks() as demo:
71
  gr.Markdown("## πŸ” BERT Sentiment Analysis with SHAP Explanations")
72
 
@@ -84,7 +80,7 @@ with gr.Blocks() as demo:
84
  gr.Markdown("""
85
  ### SHAP Explanations
86
  Below you can see how each word contributes to different sentiment scores (1-5 stars).
87
- Red text increases the score, blue decreases it.
88
  """)
89
 
90
  # Individual Explanation Rows
@@ -92,8 +88,9 @@ with gr.Blocks() as demo:
92
  for i in range(5):
93
  with gr.Row():
94
  plot_components.append(
95
- gr.HTML(
96
  label=f"Explanation for {model.config.id2label[i]}",
 
97
  elem_classes=f"shap-plot-{i+1}"
98
  )
99
  )
@@ -115,4 +112,4 @@ with gr.Blocks() as demo:
115
  )
116
 
117
  if __name__ == "__main__":
118
- demo.launch(share = True, ssr_mode = False, )
 
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
+ from io import BytesIO
8
 
9
  # Load model and tokenizer
10
  tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
 
12
 
13
  # Define prediction function
14
  def predict(texts):
 
 
 
 
 
 
 
 
15
  inputs = tokenizer(
16
+ texts,
17
  return_tensors="pt",
18
  padding=True,
19
  truncation=True,
 
41
  # Generate SHAP explanations
42
  shap_values = explainer([text])
43
 
44
+ # Create matplotlib figures for each class
45
+ images = []
46
+ for i in range(5):
47
+ plt.figure(figsize=(10, 3))
48
+ shap.plots.bar(shap_values[0, :, i], show=False)
49
+ plt.title(f"Feature importance for {output_names_list[i]}")
50
+ plt.tight_layout()
51
 
52
+ # Save plot to in-memory buffer
53
+ buf = BytesIO()
54
+ plt.savefig(buf, format="png", bbox_inches="tight")
55
+ plt.close()
56
+ buf.seek(0)
57
+ images.append(buf)
58
+
59
+ # Format confidence scores
60
+ confidence_scores = {model.config.id2label[i]: float(probabilities[i])
61
+ for i in range(len(probabilities))}
62
 
63
+ return (predicted_label, confidence_scores, *images)
64
 
65
+ # Create Gradio interface with image components
66
  with gr.Blocks() as demo:
67
  gr.Markdown("## πŸ” BERT Sentiment Analysis with SHAP Explanations")
68
 
 
80
  gr.Markdown("""
81
  ### SHAP Explanations
82
  Below you can see how each word contributes to different sentiment scores (1-5 stars).
83
+ Positive values increase the score, negative values decrease it.
84
  """)
85
 
86
  # Individual Explanation Rows
 
88
  for i in range(5):
89
  with gr.Row():
90
  plot_components.append(
91
+ gr.Image(
92
  label=f"Explanation for {model.config.id2label[i]}",
93
+ type="pil",
94
  elem_classes=f"shap-plot-{i+1}"
95
  )
96
  )
 
112
  )
113
 
114
  if __name__ == "__main__":
115
+ demo.launch(share=True)