broadfield-dev committed on
Commit
c1b4423
·
verified ·
1 Parent(s): 2c4215c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -148
app.py CHANGED
@@ -10,194 +10,174 @@ from captum.attr import IntegratedGradients
10
  import io
11
  import base64
12
  from PIL import Image
 
 
 
 
 
13
 
14
  # Initialize BERT model and tokenizer
15
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
16
- model = BertModel.from_pretrained('bert-base-uncased')
17
- model.eval()
18
-
19
- # Alternative MLP model (uncomment to use instead of BERT)
20
- """
21
- # import torch.nn as nn
22
- # class SimpleMLP(nn.Module):
23
- # def __init__(self, input_size=10, hidden_sizes=[64, 32], output_size=2):
24
- # super(SimpleMLP, self).__init__()
25
- # layers = []
26
- # prev_size = input_size
27
- # for hidden_size in hidden_sizes:
28
- # layers.append(nn.Linear(prev_size, hidden_size))
29
- # layers.append(nn.ReLU())
30
- # prev_size = hidden_size
31
- # layers.append(nn.Linear(prev_size, output_size))
32
- # self.network = nn.Sequential(*layers)
33
- # def forward(self, x):
34
- # return self.network(x)
35
- # model = SimpleMLP()
36
- # model.eval()
37
- """
38
 
39
  # Store intermediate activations
40
  activations = {}
41
  def hook_fn(module, input, output, name):
42
- activations[name] = output
43
 
44
- # Register hooks for BERT layers (or MLP layers)
45
  for name, layer in model.named_modules():
46
- if 'layer' in name or 'embeddings' in name: # Focus on transformer layers
47
  layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
48
- # For MLP, replace with:
49
- # if isinstance(layer, nn.Linear) or isinstance(layer, nn.ReLU):
50
- # layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
51
 
52
  def process_input(input_text, layer_name, visualize_option, attribution_target=0):
53
  """
54
  Process input text, compute embeddings, activations, and visualizations.
55
- Parameters:
56
- - input_text: User-provided text input
57
- - layer_name: Selected layer for visualization
58
- - visualize_option: 'Embeddings', 'Attention', or 'Activations'
59
- - attribution_target: Target class for attribution (0 or 1 for binary classification)
60
  Returns:
61
- - Dictionary with plots and dataframes
 
 
62
  """
63
  global activations
64
  activations = {} # Reset activations
65
 
66
- # Tokenize input
67
- inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
68
- input_ids = inputs['input_ids']
69
- attention_mask = inputs['attention_mask']
70
-
71
- # Forward pass
72
- with torch.no_grad():
73
- outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True, output_hidden_states=True)
74
- embeddings = outputs.last_hidden_state # [batch, seq_len, hidden_size]
75
- attentions = outputs.attentions # List of attention weights
76
- hidden_states = outputs.hidden_states # List of hidden states
77
-
78
- # Convert token IDs to tokens
79
- tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
80
-
81
- # Initialize output dictionary
82
- results = {
83
- "plots": [],
84
- "dataframes": [],
85
- "text": []
86
- }
87
-
88
- # Visualization: Embeddings (t-SNE)
89
- if visualize_option == "Embeddings":
90
- emb = embeddings[0].detach().numpy() # [seq_len, hidden_size]
91
- if emb.shape[0] > 1: # Need at least 2 points for t-SNE
92
- tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, emb.shape[0]-1))
93
- reduced = tsne.fit_transform(emb)
94
- fig, ax = plt.subplots()
95
- scatter = ax.scatter(reduced[:, 0], reduced[:, 1], c='blue')
96
- for i, token in enumerate(tokens):
97
- ax.annotate(token, (reduced[i, 0], reduced[i, 1]))
98
- ax.set_title("t-SNE of Token Embeddings")
99
- # Convert plot to base64 for Gradio
100
- buf = io.BytesIO()
101
- plt.savefig(buf, format='png')
102
- buf.seek(0)
103
- img = Image.open(buf)
104
- img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
105
- results["plots"].append(f"data:image/png;base64,{img_base64}")
106
- plt.close()
107
-
108
- # Visualization: Attention Weights
109
- if visualize_option == "Attention":
110
- if attentions:
111
- attn = attentions[-1][0, 0].detach().numpy() # Last layer, first head
112
- fig, ax = plt.subplots()
113
- sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
114
- ax.set_title("Attention Weights (Last Layer, Head 0)")
115
- plt.xticks(rotation=45)
116
- plt.yticks(rotation=0)
117
- # Convert plot to base64
118
- buf = io.BytesIO()
119
- plt.savefig(buf, format='png')
120
- buf.seek(0)
121
- img = Image.open(buf)
122
- img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
123
- results["plots"].append(f"data:image/png;base64,{img_base64}")
124
- plt.close()
125
-
126
- # Visualization: Activations
127
- if visualize_option == "Activations":
128
- if layer_name in activations:
129
  act = activations[layer_name]
130
- if isinstance(act, tuple): # Handle attention outputs
131
  act = act[0]
132
- act = act[0].detach().numpy() # [seq_len, hidden_size]
133
  df = pd.DataFrame(act, index=tokens)
134
- results["dataframes"].append(df)
135
- # Plot mean activation per token
136
  fig, ax = plt.subplots()
137
  mean_act = np.mean(act, axis=1)
138
  ax.bar(range(len(mean_act)), mean_act)
139
  ax.set_xticks(range(len(mean_act)))
140
  ax.set_xticklabels(tokens, rotation=45)
141
  ax.set_title(f"Mean Activations in {layer_name}")
142
- # Convert plot to base64
143
  buf = io.BytesIO()
144
  plt.savefig(buf, format='png')
145
  buf.seek(0)
146
- img = Image.open(buf)
147
  img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
148
- results["plots"].append(f"data:image/png;base64,{img_base64}")
149
  plt.close()
150
 
151
- # Attribution: Integrated Gradients
152
- def forward_func(inputs, attention_mask=None):
153
- outputs = model(inputs, attention_mask=attention_mask)
154
- return outputs.pooler_output[:, attribution_target]
155
-
156
- ig = IntegratedGradients(forward_func)
157
- attributions, delta = ig.attribute(
158
- inputs=input_ids,
159
- additional_forward_args=(attention_mask,),
160
- target=attribution_target,
161
- return_convergence_delta=True
162
- )
163
- attr = attributions[0].detach().numpy()
164
- attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr.sum(axis=1)})
165
- results["dataframes"].append(attr_df)
166
-
167
- # Plot attributions
168
- fig, ax = plt.subplots()
169
- ax.bar(range(len(attr_df)), attr_df["Attribution"])
170
- ax.set_xticks(range(len(attr_df)))
171
- ax.set_xticklabels(tokens, rotation=45)
172
- ax.set_title("Integrated Gradients Attribution")
173
- buf = io.BytesIO()
174
- plt.savefig(buf, format='png')
175
- buf.seek(0)
176
- img = Image.open(buf)
177
- img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
178
- results["plots"].append(f"data:image/png;base64,{img_base64}")
179
- plt.close()
180
-
181
- return (
182
- results["plots"] if results["plots"] else None,
183
- results["dataframes"] if results["dataframes"] else None,
184
- "\n".join(results["text"]) if results["text"] else "Processing complete."
185
- )
 
186
 
187
  # Gradio Interface
188
  def create_gradio_interface():
189
  with gr.Blocks(title="Neural Network Visualization Demo") as demo:
190
  gr.Markdown("# Neural Network Visualization Demo")
191
- gr.Markdown("Analyze the paths of a BERT model from input to output. Enter text, select a layer, and choose a visualization option.")
192
 
193
  with gr.Row():
194
  with gr.Column():
195
  input_text = gr.Textbox(label="Input Text", value="The quick brown fox jumps over the lazy dog.")
196
  layer_name = gr.Dropdown(
197
  label="Select Layer",
198
- choices=[name for name in model.named_modules() if 'layer' in name or 'embeddings' in name],
199
- value="embeddings",
200
- allow_custom_value=True
201
  )
202
  visualize_option = gr.Radio(
203
  label="Visualization Type",
@@ -205,7 +185,7 @@ def create_gradio_interface():
205
  value="Embeddings"
206
  )
207
  attribution_target = gr.Slider(
208
- label="Attribution Target Class (0 or 1 for binary classification)",
209
  minimum=0,
210
  maximum=1,
211
  step=1,
@@ -226,7 +206,11 @@ def create_gradio_interface():
226
 
227
  return demo
228
 
229
- # Launch the demo
230
  if __name__ == "__main__":
231
- demo = create_gradio_interface()
232
- demo.launch(share=True)
 
 
 
 
 
10
  import io
11
  import base64
12
  from PIL import Image
13
+ import logging
14
+
15
+ # Set up logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
 
19
  # Initialize BERT model and tokenizer
20
+ try:
21
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
22
+ model = BertModel.from_pretrained('bert-base-uncased')
23
+ model.eval()
24
+ except Exception as e:
25
+ logger.error(f"Failed to load BERT model: {e}")
26
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Store intermediate activations
29
  activations = {}
30
  def hook_fn(module, input, output, name):
31
+ activations[str(name)] = output # Ensure name is a string
32
 
33
+ # Register hooks for BERT layers
34
  for name, layer in model.named_modules():
35
+ if 'layer' in name or 'embeddings' in name:
36
  layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
 
 
 
37
 
38
  def process_input(input_text, layer_name, visualize_option, attribution_target=0):
39
  """
40
  Process input text, compute embeddings, activations, and visualizations.
 
 
 
 
 
41
  Returns:
42
+ - List of base64-encoded plot images
43
+ - List of dictionaries for dataframe display
44
+ - Status message
45
  """
46
  global activations
47
  activations = {} # Reset activations
48
 
49
+ try:
50
+ # Validate input
51
+ if not input_text.strip():
52
+ return [], [], "Error: Input text cannot be empty."
53
+
54
+ # Tokenize input
55
+ inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
56
+ input_ids = inputs['input_ids']
57
+ attention_mask = inputs['attention_mask']
58
+
59
+ # Forward pass
60
+ with torch.no_grad():
61
+ outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True, output_hidden_states=True)
62
+ embeddings = outputs.last_hidden_state # [batch, seq_len, hidden_size]
63
+ attentions = outputs.attentions # List of attention weights
64
+
65
+ # Convert token IDs to tokens
66
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
67
+
68
+ # Initialize outputs
69
+ plots = []
70
+ dataframes = []
71
+
72
+ # Visualization: Embeddings (t-SNE)
73
+ if visualize_option == "Embeddings":
74
+ emb = embeddings[0].detach().numpy()
75
+ if emb.shape[0] > 1:
76
+ try:
77
+ tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, emb.shape[0]-1))
78
+ reduced = tsne.fit_transform(emb)
79
+ fig, ax = plt.subplots()
80
+ ax.scatter(reduced[:, 0], reduced[:, 1], c='blue')
81
+ for i, token in enumerate(tokens):
82
+ ax.annotate(token, (reduced[i, 0], reduced[i, 1]))
83
+ ax.set_title("t-SNE of Token Embeddings")
84
+ buf = io.BytesIO()
85
+ plt.savefig(buf, format='png')
86
+ buf.seek(0)
87
+ img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
88
+ plots.append(f"data:image/png;base64,{img_base64}")
89
+ plt.close()
90
+ except Exception as e:
91
+ logger.warning(f"t-SNE failed: {e}")
92
+ dataframes.append({"Error": ["t-SNE could not be computed."]})
93
+
94
+ # Visualization: Attention Weights
95
+ if visualize_option == "Attention":
96
+ if attentions:
97
+ attn = attentions[-1][0, 0].detach().numpy()
98
+ fig, ax = plt.subplots()
99
+ sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
100
+ ax.set_title("Attention Weights (Last Layer, Head 0)")
101
+ plt.xticks(rotation=45)
102
+ plt.yticks(rotation=0)
103
+ buf = io.BytesIO()
104
+ plt.savefig(buf, format='png')
105
+ buf.seek(0)
106
+ img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
107
+ plots.append(f"data:image/png;base64,{img_base64}")
108
+ plt.close()
109
+
110
+ # Visualization: Activations
111
+ if visualize_option == "Activations" and layer_name in activations:
112
  act = activations[layer_name]
113
+ if isinstance(act, tuple):
114
  act = act[0]
115
+ act = act[0].detach().numpy()
116
  df = pd.DataFrame(act, index=tokens)
117
+ dataframes.append(df.to_dict()) # Convert to dict for serialization
 
118
  fig, ax = plt.subplots()
119
  mean_act = np.mean(act, axis=1)
120
  ax.bar(range(len(mean_act)), mean_act)
121
  ax.set_xticks(range(len(mean_act)))
122
  ax.set_xticklabels(tokens, rotation=45)
123
  ax.set_title(f"Mean Activations in {layer_name}")
 
124
  buf = io.BytesIO()
125
  plt.savefig(buf, format='png')
126
  buf.seek(0)
 
127
  img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
128
+ plots.append(f"data:image/png;base64,{img_base64}")
129
  plt.close()
130
 
131
+ # Attribution: Integrated Gradients
132
+ def forward_func(inputs, attention_mask=None):
133
+ outputs = model(inputs, attention_mask=attention_mask)
134
+ return outputs.pooler_output[:, int(attribution_target)]
135
+
136
+ ig = IntegratedGradients(forward_func)
137
+ try:
138
+ attributions, _ = ig.attribute(
139
+ inputs=input_ids,
140
+ additional_forward_args=(attention_mask,),
141
+ target=int(attribution_target),
142
+ return_convergence_delta=True
143
+ )
144
+ attr = attributions[0].detach().numpy().sum(axis=1)
145
+ attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
146
+ dataframes.append(attr_df.to_dict())
147
+ fig, ax = plt.subplots()
148
+ ax.bar(range(len(attr)), attr)
149
+ ax.set_xticks(range(len(attr)))
150
+ ax.set_xticklabels(tokens, rotation=45)
151
+ ax.set_title("Integrated Gradients Attribution")
152
+ buf = io.BytesIO()
153
+ plt.savefig(buf, format='png')
154
+ buf.seek(0)
155
+ img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
156
+ plots.append(f"data:image/png;base64,{img_base64}")
157
+ plt.close()
158
+ except Exception as e:
159
+ logger.warning(f"Integrated Gradients failed: {e}")
160
+ dataframes.append({"Error": ["Attribution could not be computed."]})
161
+
162
+ return plots, dataframes, "Processing complete."
163
+
164
+ except Exception as e:
165
+ logger.error(f"Processing failed: {e}")
166
+ return [], [{"Error": [str(e)]}], f"Error: {e}"
167
 
168
  # Gradio Interface
169
  def create_gradio_interface():
170
  with gr.Blocks(title="Neural Network Visualization Demo") as demo:
171
  gr.Markdown("# Neural Network Visualization Demo")
172
+ gr.Markdown("Analyze BERT's neural network paths. Enter text, select a layer, and choose a visualization.")
173
 
174
  with gr.Row():
175
  with gr.Column():
176
  input_text = gr.Textbox(label="Input Text", value="The quick brown fox jumps over the lazy dog.")
177
  layer_name = gr.Dropdown(
178
  label="Select Layer",
179
+ choices=[str(name) for name, _ in model.named_modules() if 'layer' in name or 'embeddings' in name],
180
+ value="embeddings"
 
181
  )
182
  visualize_option = gr.Radio(
183
  label="Visualization Type",
 
185
  value="Embeddings"
186
  )
187
  attribution_target = gr.Slider(
188
+ label="Attribution Target Class (0 or 1)",
189
  minimum=0,
190
  maximum=1,
191
  step=1,
 
206
 
207
  return demo
208
 
209
+ # Launch the demo locally
210
  if __name__ == "__main__":
211
+ try:
212
+ demo = create_gradio_interface()
213
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
214
+ except Exception as e:
215
+ logger.error(f"Failed to launch Gradio demo: {e}")
216
+ print(f"Error launching demo: {e}. Try running locally without share=True.")