broadfield-dev committed
Commit 8d87543 · verified · 1 Parent(s): aed33df

Update app.py

Files changed (1)
  1. app.py +104 -25
app.py CHANGED
@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
 from transformers import BertTokenizer, BertModel
 from sklearn.manifold import TSNE
 import seaborn as sns
+from captum.attr import IntegratedGradients
 import io
 import base64
 import logging
@@ -33,16 +34,18 @@ for name, layer in model.named_modules():
     if 'layer' in name or 'embeddings' in name:
         layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))

-def process_input(input_text, visualize_option):
+def process_input(input_text, layer_name, visualize_option, attribution_target=0):
     """
-    Process input text and generate visualizations for BERT embeddings or attention.
+    Process input text, compute embeddings, activations, attention, and attribution.
     Parameters:
     - input_text: User-provided text
-    - visualize_option: 'Embeddings' or 'Attention'
+    - layer_name: Selected layer for activation visualization
+    - visualize_option: 'Embeddings', 'Attention', or 'Activations'
+    - attribution_target: Target class for attribution (0 or 1)
     Returns:
-    - Base64-encoded plot image (str)
-    - Dataframe dictionary (dict)
-    - Status message (str)
+    - HTML string with base64-encoded image(s)
+    - List of dataframe dictionaries
+    - Status message
     """
     global activations
     activations = {} # Reset activations
@@ -50,7 +53,7 @@ def process_input(input_text, visualize_option):
     try:
         # Validate input
         if not input_text.strip():
-            return None, {"Error": ["Input text cannot be empty."]}, "Error: Input text cannot be empty."
+            return "<p>Error: Input text cannot be empty.</p>", [{"Error": ["Input text cannot be empty."]}], "Error: Input text cannot be empty."

         # Tokenize input
         inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
@@ -67,8 +70,8 @@ def process_input(input_text, visualize_option):
         tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

         # Initialize outputs
-        plot_data = None
-        dataframe = None
+        html_plots = []
+        dataframes = []

         # Visualization: Embeddings (t-SNE)
         if visualize_option == "Embeddings":
@@ -86,7 +89,7 @@ def process_input(input_text, visualize_option):
                     plt.savefig(buf, format='png', bbox_inches='tight')
                     buf.seek(0)
                     img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
-                    plot_data = f"data:image/png;base64,{img_base64}"
+                    html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="t-SNE Plot" style="max-width:100%;"/>')
                     plt.close()
                     # Dataframe for coordinates
                     dataframe = pd.DataFrame({
@@ -94,13 +97,14 @@ def process_input(input_text, visualize_option):
                         "t-SNE X": reduced[:, 0],
                         "t-SNE Y": reduced[:, 1]
                     }).to_dict()
+                    dataframes.append(dataframe)
                 except Exception as e:
                     logger.warning(f"t-SNE failed: {e}")
-                    dataframe = {"Error": [str(e)]}
-                    return None, dataframe, f"Error: t-SNE computation failed: {e}"
+                    dataframes.append({"Error": [str(e)]})
+                    html_plots.append("<p>Error: t-SNE computation failed.</p>")
             else:
-                dataframe = {"Error": ["Too few tokens for t-SNE."]}
-                return None, dataframe, "Error: Too few tokens for t-SNE."
+                dataframes.append({"Error": ["Too few tokens for t-SNE."]})
+                html_plots.append("<p>Error: Too few tokens for t-SNE.</p>")

         # Visualization: Attention Weights
         elif visualize_option == "Attention":
@@ -115,25 +119,88 @@ def process_input(input_text, visualize_option):
                 plt.savefig(buf, format='png', bbox_inches='tight')
                 buf.seek(0)
                 img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
-                plot_data = f"data:image/png;base64,{img_base64}"
+                html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attention Heatmap" style="max-width:100%;"/>')
                 plt.close()
                 # Dataframe for attention weights
                 dataframe = pd.DataFrame(attn, index=tokens, columns=tokens).to_dict()
+                dataframes.append(dataframe)
             else:
-                dataframe = {"Error": ["No attention weights available."]}
-                return None, dataframe, "Error: No attention weights available."
-
-        return plot_data, dataframe, "Processing complete."
+                dataframes.append({"Error": ["No attention weights available."]})
+                html_plots.append("<p>Error: No attention weights available.</p>")
+
+        # Visualization: Activations
+        elif visualize_option == "Activations":
+            if layer_name in activations:
+                act = activations[layer_name]
+                if isinstance(act, tuple):
+                    act = act[0]
+                act = act[0].detach().numpy() # [seq_len, hidden_size]
+                dataframe = pd.DataFrame(act, index=tokens).to_dict()
+                dataframes.append(dataframe)
+                # Plot mean activation per token
+                fig, ax = plt.subplots(figsize=(8, 6))
+                mean_act = np.mean(act, axis=1)
+                ax.bar(range(len(mean_act)), mean_act)
+                ax.set_xticks(range(len(mean_act)))
+                ax.set_xticklabels(tokens, rotation=45)
+                ax.set_title(f"Mean Activations in {layer_name}")
+                buf = io.BytesIO()
+                plt.savefig(buf, format='png', bbox_inches='tight')
+                buf.seek(0)
+                img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
+                html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Activations Plot" style="max-width:100%;"/>')
+                plt.close()
+            else:
+                dataframes.append({"Error": [f"Layer {layer_name} not found."]})
+                html_plots.append(f"<p>Error: Layer {layer_name} not found.</p>")
+
+        # Attribution: Integrated Gradients
+        def forward_func(inputs, attention_mask=None):
+            outputs = model(inputs, attention_mask=attention_mask)
+            return outputs.pooler_output[:, int(attribution_target)]
+
+        ig = IntegratedGradients(forward_func)
+        try:
+            attributions, _ = ig.attribute(
+                inputs=input_ids,
+                additional_forward_args=(attention_mask,),
+                target=int(attribution_target),
+                return_convergence_delta=True
+            )
+            attr = attributions[0].detach().numpy().sum(axis=1)
+            attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr}).to_dict()
+            dataframes.append(attr_df)
+            # Plot attributions
+            fig, ax = plt.subplots(figsize=(8, 6))
+            ax.bar(range(len(attr)), attr)
+            ax.set_xticks(range(len(attr)))
+            ax.set_xticklabels(tokens, rotation=45)
+            ax.set_title("Integrated Gradients Attribution")
+            buf = io.BytesIO()
+            plt.savefig(buf, format='png', bbox_inches='tight')
+            buf.seek(0)
+            img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
+            html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100%;"/>')
+            plt.close()
+        except Exception as e:
+            logger.warning(f"Integrated Gradients failed: {e}")
+            dataframes.append({"Error": [str(e)]})
+            html_plots.append("<p>Error: Attribution computation failed.</p>")
+
+        # Combine HTML plots
+        html_output = "<div>" + "".join(html_plots) + "</div>"
+
+        return html_output, dataframes, "Processing complete."

     except Exception as e:
         logger.error(f"Processing failed: {e}")
-        return None, {"Error": [str(e)]}, f"Error: {e}"
+        return f"<p>Error: {e}</p>", [{"Error": [str(e)]}], f"Error: {e}"

 # Gradio Interface
 def create_gradio_interface():
     with gr.Blocks(title="Neural Network Visualization Demo") as demo:
         gr.Markdown("# Neural Network Visualization Demo")
-        gr.Markdown("Visualize BERT embeddings or attention weights. Enter text and select a visualization type.")
+        gr.Markdown("Analyze BERT's neural network paths. Enter text, select a layer, and choose a visualization.")

         with gr.Row():
             with gr.Column():
@@ -142,21 +209,33 @@ def create_gradio_interface():
                     value="The quick brown fox jumps over the lazy dog.",
                     placeholder="Enter text here..."
                 )
+                layer_name = gr.Dropdown(
+                    label="Select Layer",
+                    choices=[str(name) for name, _ in model.named_modules() if 'layer' in name or 'embeddings' in name],
+                    value="embeddings"
+                )
                 visualize_option = gr.Radio(
                     label="Visualization Type",
-                    choices=["Embeddings", "Attention"],
+                    choices=["Embeddings", "Attention", "Activations"],
                     value="Embeddings"
                 )
+                attribution_target = gr.Slider(
+                    label="Attribution Target Class (0 or 1)",
+                    minimum=0,
+                    maximum=1,
+                    step=1,
+                    value=0
+                )
                 submit_btn = gr.Button("Analyze")

             with gr.Column():
-                plot_output = gr.Image(label="Visualization", type="pil")
-                dataframe_output = gr.Dataframe(label="Data Output")
+                plot_output = gr.HTML(label="Visualizations")
+                dataframe_output = gr.Dataframe(label="Data Outputs")
                 text_output = gr.Textbox(label="Messages")

         submit_btn.click(
             fn=process_input,
-            inputs=[input_text, visualize_option],
+            inputs=[input_text, layer_name, visualize_option, attribution_target],
             outputs=[plot_output, dataframe_output, text_output]
         )
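
Note on the hook setup: the context lines above reference a module-level `activations` dict and a `hook_fn` callback that are defined earlier in app.py and are not part of this diff. A minimal sketch of that capture pattern, assuming `hook_fn` simply stores each hooked module's forward output by name and that the checkpoint is `bert-base-uncased` (the diff does not show which checkpoint is loaded):

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')   # assumed checkpoint
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()

activations = {}

def hook_fn(module, inputs, output, name):
    # Assumed body: store each hooked module's forward output under its dotted name.
    activations[name] = output

for name, layer in model.named_modules():
    if 'layer' in name or 'embeddings' in name:
        # n=name binds the current name so every hook reports its own module.
        layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))

enc = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors='pt')
with torch.no_grad():
    model(**enc)
# activations now maps names such as 'embeddings' or 'encoder.layer.0' to output
# tensors (or tuples, which is why the "Activations" branch unpacks tuples first).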
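
Note on the attribution block: the new code calls `IntegratedGradients.attribute` directly on the integer `input_ids`, and also passes `target` even though `forward_func` already selects a single pooler-output column, so at runtime this path may fail and fall through to the surrounding `except`. A common alternative, which is not what this commit implements, is `LayerIntegratedGradients` over the embedding layer. A minimal sketch, reusing the `model`, `tokenizer`, and `attribution_target` names from the diff (assumed here to be module-level objects):

from captum.attr import LayerIntegratedGradients

attribution_target = 0  # assumed; the app takes this value from a Gradio slider

def forward_func(input_ids, attention_mask=None):
    outputs = model(input_ids, attention_mask=attention_mask)
    # Returns a [batch]-shaped tensor, so no separate `target` argument is needed.
    return outputs.pooler_output[:, int(attribution_target)]

# Attribute with respect to the embedding layer rather than the integer token ids.
lig = LayerIntegratedGradients(forward_func, model.embeddings)

enc = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors='pt')
attributions, delta = lig.attribute(
    inputs=enc['input_ids'],
    additional_forward_args=(enc['attention_mask'],),
    return_convergence_delta=True,
)
# Sum over the hidden dimension to get one attribution score per token.
token_scores = attributions.squeeze(0).sum(dim=-1).detach().numpy()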