broadfield-dev committed
Commit aed33df · verified · 1 Parent(s): c1b4423

Update app.py

Files changed (1)
  1. app.py +48 -92
app.py CHANGED
@@ -6,10 +6,8 @@ import matplotlib.pyplot as plt
 from transformers import BertTokenizer, BertModel
 from sklearn.manifold import TSNE
 import seaborn as sns
-from captum.attr import IntegratedGradients
 import io
 import base64
-from PIL import Image
 import logging
 
 # Set up logging
@@ -35,13 +33,16 @@ for name, layer in model.named_modules():
     if 'layer' in name or 'embeddings' in name:
         layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
 
-def process_input(input_text, layer_name, visualize_option, attribution_target=0):
+def process_input(input_text, visualize_option):
     """
-    Process input text, compute embeddings, activations, and visualizations.
+    Process input text and generate visualizations for BERT embeddings or attention.
+    Parameters:
+    - input_text: User-provided text
+    - visualize_option: 'Embeddings' or 'Attention'
     Returns:
-    - List of base64-encoded plot images
-    - List of dictionaries for dataframe display
-    - Status message
+    - Base64-encoded plot image (str)
+    - Dataframe dictionary (dict)
+    - Status message (str)
     """
     global activations
     activations = {} # Reset activations
@@ -49,7 +50,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
     try:
         # Validate input
         if not input_text.strip():
-            return [], [], "Error: Input text cannot be empty."
+            return None, {"Error": ["Input text cannot be empty."]}, "Error: Input text cannot be empty."
 
         # Tokenize input
         inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
@@ -66,141 +67,96 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
         tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
 
         # Initialize outputs
-        plots = []
-        dataframes = []
+        plot_data = None
+        dataframe = None
 
         # Visualization: Embeddings (t-SNE)
         if visualize_option == "Embeddings":
-            emb = embeddings[0].detach().numpy()
+            emb = embeddings[0].detach().numpy() # [seq_len, hidden_size]
             if emb.shape[0] > 1:
                 try:
                     tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, emb.shape[0]-1))
                     reduced = tsne.fit_transform(emb)
-                    fig, ax = plt.subplots()
+                    fig, ax = plt.subplots(figsize=(8, 6))
                     ax.scatter(reduced[:, 0], reduced[:, 1], c='blue')
                     for i, token in enumerate(tokens):
                         ax.annotate(token, (reduced[i, 0], reduced[i, 1]))
                     ax.set_title("t-SNE of Token Embeddings")
                     buf = io.BytesIO()
-                    plt.savefig(buf, format='png')
+                    plt.savefig(buf, format='png', bbox_inches='tight')
                     buf.seek(0)
                     img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
-                    plots.append(f"data:image/png;base64,{img_base64}")
+                    plot_data = f"data:image/png;base64,{img_base64}"
                     plt.close()
+                    # Dataframe for coordinates
+                    dataframe = pd.DataFrame({
+                        "Token": tokens,
+                        "t-SNE X": reduced[:, 0],
+                        "t-SNE Y": reduced[:, 1]
+                    }).to_dict()
                 except Exception as e:
                     logger.warning(f"t-SNE failed: {e}")
-                    dataframes.append({"Error": ["t-SNE could not be computed."]})
+                    dataframe = {"Error": [str(e)]}
+                    return None, dataframe, f"Error: t-SNE computation failed: {e}"
+            else:
+                dataframe = {"Error": ["Too few tokens for t-SNE."]}
+                return None, dataframe, "Error: Too few tokens for t-SNE."
 
         # Visualization: Attention Weights
-        if visualize_option == "Attention":
+        elif visualize_option == "Attention":
             if attentions:
-                attn = attentions[-1][0, 0].detach().numpy()
-                fig, ax = plt.subplots()
+                attn = attentions[-1][0, 0].detach().numpy() # Last layer, first head
+                fig, ax = plt.subplots(figsize=(8, 6))
                 sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
                 ax.set_title("Attention Weights (Last Layer, Head 0)")
                 plt.xticks(rotation=45)
                 plt.yticks(rotation=0)
                 buf = io.BytesIO()
-                plt.savefig(buf, format='png')
+                plt.savefig(buf, format='png', bbox_inches='tight')
                 buf.seek(0)
                 img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
-                plots.append(f"data:image/png;base64,{img_base64}")
+                plot_data = f"data:image/png;base64,{img_base64}"
                 plt.close()
+                # Dataframe for attention weights
+                dataframe = pd.DataFrame(attn, index=tokens, columns=tokens).to_dict()
+            else:
+                dataframe = {"Error": ["No attention weights available."]}
+                return None, dataframe, "Error: No attention weights available."
 
-        # Visualization: Activations
-        if visualize_option == "Activations" and layer_name in activations:
-            act = activations[layer_name]
-            if isinstance(act, tuple):
-                act = act[0]
-            act = act[0].detach().numpy()
-            df = pd.DataFrame(act, index=tokens)
-            dataframes.append(df.to_dict()) # Convert to dict for serialization
-            fig, ax = plt.subplots()
-            mean_act = np.mean(act, axis=1)
-            ax.bar(range(len(mean_act)), mean_act)
-            ax.set_xticks(range(len(mean_act)))
-            ax.set_xticklabels(tokens, rotation=45)
-            ax.set_title(f"Mean Activations in {layer_name}")
-            buf = io.BytesIO()
-            plt.savefig(buf, format='png')
-            buf.seek(0)
-            img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
-            plots.append(f"data:image/png;base64,{img_base64}")
-            plt.close()
-
-        # Attribution: Integrated Gradients
-        def forward_func(inputs, attention_mask=None):
-            outputs = model(inputs, attention_mask=attention_mask)
-            return outputs.pooler_output[:, int(attribution_target)]
-
-        ig = IntegratedGradients(forward_func)
-        try:
-            attributions, _ = ig.attribute(
-                inputs=input_ids,
-                additional_forward_args=(attention_mask,),
-                target=int(attribution_target),
-                return_convergence_delta=True
-            )
-            attr = attributions[0].detach().numpy().sum(axis=1)
-            attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
-            dataframes.append(attr_df.to_dict())
-            fig, ax = plt.subplots()
-            ax.bar(range(len(attr)), attr)
-            ax.set_xticks(range(len(attr)))
-            ax.set_xticklabels(tokens, rotation=45)
-            ax.set_title("Integrated Gradients Attribution")
-            buf = io.BytesIO()
-            plt.savefig(buf, format='png')
-            buf.seek(0)
-            img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
-            plots.append(f"data:image/png;base64,{img_base64}")
-            plt.close()
-        except Exception as e:
-            logger.warning(f"Integrated Gradients failed: {e}")
-            dataframes.append({"Error": ["Attribution could not be computed."]})
-
-        return plots, dataframes, "Processing complete."
+        return plot_data, dataframe, "Processing complete."
 
     except Exception as e:
        logger.error(f"Processing failed: {e}")
-        return [], [{"Error": [str(e)]}], f"Error: {e}"
+        return None, {"Error": [str(e)]}, f"Error: {e}"
 
 # Gradio Interface
 def create_gradio_interface():
     with gr.Blocks(title="Neural Network Visualization Demo") as demo:
         gr.Markdown("# Neural Network Visualization Demo")
-        gr.Markdown("Analyze BERT's neural network paths. Enter text, select a layer, and choose a visualization.")
+        gr.Markdown("Visualize BERT embeddings or attention weights. Enter text and select a visualization type.")
 
         with gr.Row():
             with gr.Column():
-                input_text = gr.Textbox(label="Input Text", value="The quick brown fox jumps over the lazy dog.")
-                layer_name = gr.Dropdown(
-                    label="Select Layer",
-                    choices=[str(name) for name, _ in model.named_modules() if 'layer' in name or 'embeddings' in name],
-                    value="embeddings"
+                input_text = gr.Textbox(
+                    label="Input Text",
+                    value="The quick brown fox jumps over the lazy dog.",
+                    placeholder="Enter text here..."
                 )
                 visualize_option = gr.Radio(
                     label="Visualization Type",
-                    choices=["Embeddings", "Attention", "Activations"],
+                    choices=["Embeddings", "Attention"],
                     value="Embeddings"
                 )
-                attribution_target = gr.Slider(
-                    label="Attribution Target Class (0 or 1)",
-                    minimum=0,
-                    maximum=1,
-                    step=1,
-                    value=0
-                )
                 submit_btn = gr.Button("Analyze")
 
             with gr.Column():
-                plot_output = gr.Gallery(label="Visualizations")
-                dataframe_output = gr.Dataframe(label="Data Outputs")
+                plot_output = gr.Image(label="Visualization", type="pil")
+                dataframe_output = gr.Dataframe(label="Data Output")
                 text_output = gr.Textbox(label="Messages")
 
         submit_btn.click(
             fn=process_input,
-            inputs=[input_text, layer_name, visualize_option, attribution_target],
+            inputs=[input_text, visualize_option],
             outputs=[plot_output, dataframe_output, text_output]
         )
 
@@ -213,4 +169,4 @@ if __name__ == "__main__":
         demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
     except Exception as e:
         logger.error(f"Failed to launch Gradio demo: {e}")
-        print(f"Error launching demo: {e}. Try running locally without share=True.")
+        print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")