broadfield-dev committed on
Commit
5a6bc13
·
verified ·
1 Parent(s): e248142

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -0
app.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ from transformers import BertTokenizer, BertModel
7
+ from sklearn.manifold import TSNE
8
+ import seaborn as sns
9
+ from captum.attr import IntegratedGradients
10
+ import io
11
+ import base64
12
+ from PIL import Image
13
+
14
+ # Initialize BERT model and tokenizer
15
# Load the pretrained tokenizer and encoder once at import time; the model is
# switched to evaluation mode straight away (nn.Module.eval() returns the module).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').eval()
18
+
19
# Alternative MLP model. To use it instead of BERT, remove the enclosing
# triple quotes and the leading "# " from each line below.
"""
# import torch.nn as nn
# class SimpleMLP(nn.Module):
#     def __init__(self, input_size=10, hidden_sizes=[64, 32], output_size=2):
#         super(SimpleMLP, self).__init__()
#         layers = []
#         prev_size = input_size
#         for hidden_size in hidden_sizes:
#             layers.append(nn.Linear(prev_size, hidden_size))
#             layers.append(nn.ReLU())
#             prev_size = hidden_size
#         layers.append(nn.Linear(prev_size, output_size))
#         self.network = nn.Sequential(*layers)
#     def forward(self, x):
#         return self.network(x)
# model = SimpleMLP()
# model.eval()
"""
38
+
39
+ # Store intermediate activations
40
activations = {}


def hook_fn(module, input, output, name):
    """Forward-hook callback that records a module's output.

    Parameters:
    - module: the module whose forward pass just ran (unused)
    - input: the inputs that were passed to that module (unused)
    - output: the value the module returned; stored as-is
    - name: key under which `output` is stored in the module-level
      `activations` dict (an existing entry with the same key is overwritten)
    """
    activations.update({name: output})
43
+
44
+ # Register hooks for BERT layers (or MLP layers)
45
# Attach a recording hook to every sub-module whose qualified name mentions
# 'layer' or 'embeddings' (focus on transformer layers).
for name, layer in model.named_modules():
    if not any(keyword in name for keyword in ('layer', 'embeddings')):
        continue
    # Bind the current name as a default argument so each hook keeps its own
    # key instead of sharing the loop variable's final value.
    layer.register_forward_hook(
        lambda module, inputs, output, layer_key=name: hook_fn(module, inputs, output, layer_key)
    )
    # For MLP, replace the name filter above with:
    # if isinstance(layer, nn.Linear) or isinstance(layer, nn.ReLU):
    #     layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
51
+
52
def _figure_to_image(fig):
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    The figure is always closed afterwards so repeated requests do not pile up
    open figures inside pyplot.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # Image.open is lazy; decode now while the buffer is still alive
    return img


def _labelled_bar_plot(values, tokens, title):
    """Return a PIL image of a bar chart with one bar per token."""
    fig, ax = plt.subplots()
    positions = range(len(values))
    ax.bar(positions, values)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=45)
    ax.set_title(title)
    return _figure_to_image(fig)


def process_input(input_text, layer_name, visualize_option, attribution_target=0):
    """
    Process input text, compute embeddings, activations, and visualizations.

    Parameters:
    - input_text: User-provided text input
    - layer_name: Name of a hooked module whose recorded output is shown when
      visualize_option is 'Activations'
    - visualize_option: 'Embeddings', 'Attention', or 'Activations'
    - attribution_target: Index into the model's pooled output that the
      Integrated Gradients attribution is computed for (coerced to int)

    Returns:
    - A 3-tuple (plots, table, message):
      * plots: list of PIL images (the selected visualization, if it could be
        drawn, followed by the attribution bar chart), or None if empty
      * table: a pandas DataFrame with one row per token holding its summed
        attribution; in 'Activations' mode it also carries the token's mean
        activation in the selected layer. A single frame is returned because
        the value is bound to a single Dataframe component.
      * message: newline-joined status notes, or "Processing complete."
    """
    global activations
    activations = {}  # Reset; the forward hooks repopulate it during the forward pass

    if not input_text or not input_text.strip():
        return None, None, "Please enter some text to analyze."

    # UI sliders may hand over a float; tensor indexing needs a real int.
    target_index = int(attribution_target)

    plots = []
    messages = []
    mean_act = None

    # Tokenize input
    inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Forward pass (no gradients needed for the plain visualizations)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True)
    embeddings = outputs.last_hidden_state
    attentions = outputs.attentions

    # Convert token IDs to tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    if visualize_option == "Embeddings":
        emb = embeddings[0].detach().numpy()
        n_tokens = emb.shape[0]
        if n_tokens > 1:  # Need at least 2 points for t-SNE
            tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, n_tokens - 1))
            reduced = tsne.fit_transform(emb)
            fig, ax = plt.subplots()
            ax.scatter(reduced[:, 0], reduced[:, 1], c='blue')
            for token, point in zip(tokens, reduced):
                ax.annotate(token, (point[0], point[1]))
            ax.set_title("t-SNE of Token Embeddings")
            plots.append(_figure_to_image(fig))
        else:
            messages.append("Need at least two tokens for a t-SNE plot.")

    elif visualize_option == "Attention":
        if attentions:
            attn = attentions[-1][0, 0].detach().numpy()  # Last layer, first head
            fig, ax = plt.subplots()
            sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
            ax.set_title("Attention Weights (Last Layer, Head 0)")
            ax.tick_params(axis='x', labelrotation=45)
            ax.tick_params(axis='y', labelrotation=0)
            plots.append(_figure_to_image(fig))
        else:
            messages.append("The model returned no attention weights.")

    elif visualize_option == "Activations":
        act = activations.get(layer_name)
        if isinstance(act, tuple):  # Some modules return (tensor, extras...)
            act = act[0] if act else None
        if not torch.is_tensor(act):
            messages.append(f"No activations were recorded for layer '{layer_name}'.")
        else:
            act = act[0].detach().numpy()
            # Only per-token matrices can be tabulated against the token list;
            # e.g. attention-probability tensors have a different rank.
            if act.ndim != 2 or act.shape[0] != len(tokens):
                messages.append(
                    f"Layer '{layer_name}' output has shape {act.shape}, "
                    "which is not one row per token; skipping activation plot."
                )
            else:
                mean_act = act.mean(axis=1)
                plots.append(_labelled_bar_plot(mean_act, tokens, f"Mean Activations in {layer_name}"))

    # Attribution: Integrated Gradients.
    # Token ids are integers and cannot carry gradients, so the path integral is
    # taken over the embedding layer's output via LayerIntegratedGradients.
    from captum.attr import LayerIntegratedGradients

    def forward_func(ids, mask=None):
        return model(ids, attention_mask=mask).pooler_output

    lig = LayerIntegratedGradients(forward_func, model.embeddings)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    baseline_ids = torch.full_like(input_ids, pad_id)
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(attention_mask,),
        target=target_index,
        return_convergence_delta=True,
    )
    # Collapse the hidden dimension to get one attribution score per token.
    token_attr = attributions[0].detach().numpy().sum(axis=-1)
    table = pd.DataFrame({"Token": tokens, "Attribution": token_attr})
    if mean_act is not None:
        table["Mean Activation"] = mean_act
    messages.append(f"Integrated Gradients convergence delta: {float(delta.abs().max()):.4g}")

    plots.append(_labelled_bar_plot(token_attr, tokens, "Integrated Gradients Attribution"))

    return (
        plots if plots else None,
        table,
        "\n".join(messages) if messages else "Processing complete."
    )
186
+
187
+ # Gradio Interface
188
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the visualization demo.

    The left column collects the text, the layer to inspect, the visualization
    type and the attribution target; the right column shows the resulting
    gallery, table and status message. The "Analyze" button is wired to
    `process_input`.

    Returns:
    - The constructed `gr.Blocks` object (not yet launched)
    """
    # named_modules() yields (name, module) pairs, so unpack before filtering
    # on the name. Bare ModuleList containers are skipped: they are never
    # called directly, so a forward hook on them records nothing.
    layer_choices = [
        name
        for name, module in model.named_modules()
        if ('layer' in name or 'embeddings' in name)
        and not isinstance(module, torch.nn.ModuleList)
    ]
    if "embeddings" in layer_choices:
        default_layer = "embeddings"
    else:
        default_layer = layer_choices[0] if layer_choices else None

    with gr.Blocks(title="Neural Network Visualization Demo") as demo:
        gr.Markdown("# Neural Network Visualization Demo")
        gr.Markdown("Analyze the paths of a BERT model from input to output. Enter text, select a layer, and choose a visualization option.")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input Text", value="The quick brown fox jumps over the lazy dog.")
                layer_name = gr.Dropdown(
                    label="Select Layer",
                    choices=layer_choices,
                    value=default_layer
                )
                visualize_option = gr.Radio(
                    label="Visualization Type",
                    choices=["Embeddings", "Attention", "Activations"],
                    value="Embeddings"
                )
                attribution_target = gr.Slider(
                    label="Attribution Target Class (0 or 1 for binary classification)",
                    minimum=0,
                    maximum=1,
                    step=1,
                    value=0
                )
                submit_btn = gr.Button("Analyze")

            with gr.Column():
                plot_output = gr.Gallery(label="Visualizations")
                dataframe_output = gr.Dataframe(label="Data Outputs")
                text_output = gr.Textbox(label="Messages")

        submit_btn.click(
            fn=process_input,
            inputs=[input_text, layer_name, visualize_option, attribution_target],
            outputs=[plot_output, dataframe_output, text_output]
        )

    return demo
227
+
228
+ # Launch the demo
229
# Build the UI and start serving only when executed as a script, so importing
# this module does not launch a server.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()