Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
|
|
6 |
from transformers import BertTokenizer, BertModel
|
7 |
from sklearn.manifold import TSNE
|
8 |
import seaborn as sns
|
|
|
9 |
import io
|
10 |
import base64
|
11 |
import logging
|
@@ -33,16 +34,18 @@ for name, layer in model.named_modules():
|
|
33 |
if 'layer' in name or 'embeddings' in name:
|
34 |
layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
|
35 |
|
36 |
-
def process_input(input_text, visualize_option):
|
37 |
"""
|
38 |
-
Process input text
|
39 |
Parameters:
|
40 |
- input_text: User-provided text
|
41 |
-
-
|
|
|
|
|
42 |
Returns:
|
43 |
-
-
|
44 |
-
-
|
45 |
-
- Status message
|
46 |
"""
|
47 |
global activations
|
48 |
activations = {} # Reset activations
|
@@ -50,7 +53,7 @@ def process_input(input_text, visualize_option):
|
|
50 |
try:
|
51 |
# Validate input
|
52 |
if not input_text.strip():
|
53 |
-
return
|
54 |
|
55 |
# Tokenize input
|
56 |
inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
@@ -67,8 +70,8 @@ def process_input(input_text, visualize_option):
|
|
67 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
68 |
|
69 |
# Initialize outputs
|
70 |
-
|
71 |
-
|
72 |
|
73 |
# Visualization: Embeddings (t-SNE)
|
74 |
if visualize_option == "Embeddings":
|
@@ -86,7 +89,7 @@ def process_input(input_text, visualize_option):
|
|
86 |
plt.savefig(buf, format='png', bbox_inches='tight')
|
87 |
buf.seek(0)
|
88 |
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
89 |
-
|
90 |
plt.close()
|
91 |
# Dataframe for coordinates
|
92 |
dataframe = pd.DataFrame({
|
@@ -94,13 +97,14 @@ def process_input(input_text, visualize_option):
|
|
94 |
"t-SNE X": reduced[:, 0],
|
95 |
"t-SNE Y": reduced[:, 1]
|
96 |
}).to_dict()
|
|
|
97 |
except Exception as e:
|
98 |
logger.warning(f"t-SNE failed: {e}")
|
99 |
-
|
100 |
-
|
101 |
else:
|
102 |
-
|
103 |
-
|
104 |
|
105 |
# Visualization: Attention Weights
|
106 |
elif visualize_option == "Attention":
|
@@ -115,25 +119,88 @@ def process_input(input_text, visualize_option):
|
|
115 |
plt.savefig(buf, format='png', bbox_inches='tight')
|
116 |
buf.seek(0)
|
117 |
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
118 |
-
|
119 |
plt.close()
|
120 |
# Dataframe for attention weights
|
121 |
dataframe = pd.DataFrame(attn, index=tokens, columns=tokens).to_dict()
|
|
|
122 |
else:
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
except Exception as e:
|
129 |
logger.error(f"Processing failed: {e}")
|
130 |
-
return
|
131 |
|
132 |
# Gradio Interface
|
133 |
def create_gradio_interface():
|
134 |
with gr.Blocks(title="Neural Network Visualization Demo") as demo:
|
135 |
gr.Markdown("# Neural Network Visualization Demo")
|
136 |
-
gr.Markdown("
|
137 |
|
138 |
with gr.Row():
|
139 |
with gr.Column():
|
@@ -142,21 +209,33 @@ def create_gradio_interface():
|
|
142 |
value="The quick brown fox jumps over the lazy dog.",
|
143 |
placeholder="Enter text here..."
|
144 |
)
|
|
|
|
|
|
|
|
|
|
|
145 |
visualize_option = gr.Radio(
|
146 |
label="Visualization Type",
|
147 |
-
choices=["Embeddings", "Attention"],
|
148 |
value="Embeddings"
|
149 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
submit_btn = gr.Button("Analyze")
|
151 |
|
152 |
with gr.Column():
|
153 |
-
plot_output = gr.
|
154 |
-
dataframe_output = gr.Dataframe(label="Data
|
155 |
text_output = gr.Textbox(label="Messages")
|
156 |
|
157 |
submit_btn.click(
|
158 |
fn=process_input,
|
159 |
-
inputs=[input_text, visualize_option],
|
160 |
outputs=[plot_output, dataframe_output, text_output]
|
161 |
)
|
162 |
|
|
|
6 |
from transformers import BertTokenizer, BertModel
|
7 |
from sklearn.manifold import TSNE
|
8 |
import seaborn as sns
|
9 |
+
from captum.attr import IntegratedGradients
|
10 |
import io
|
11 |
import base64
|
12 |
import logging
|
|
|
34 |
if 'layer' in name or 'embeddings' in name:
|
35 |
layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
|
36 |
|
37 |
+
def process_input(input_text, layer_name, visualize_option, attribution_target=0):
|
38 |
"""
|
39 |
+
Process input text, compute embeddings, activations, attention, and attribution.
|
40 |
Parameters:
|
41 |
- input_text: User-provided text
|
42 |
+
- layer_name: Selected layer for activation visualization
|
43 |
+
- visualize_option: 'Embeddings', 'Attention', or 'Activations'
|
44 |
+
- attribution_target: Target class for attribution (0 or 1)
|
45 |
Returns:
|
46 |
+
- HTML string with base64-encoded image(s)
|
47 |
+
- List of dataframe dictionaries
|
48 |
+
- Status message
|
49 |
"""
|
50 |
global activations
|
51 |
activations = {} # Reset activations
|
|
|
53 |
try:
|
54 |
# Validate input
|
55 |
if not input_text.strip():
|
56 |
+
return "<p>Error: Input text cannot be empty.</p>", [{"Error": ["Input text cannot be empty."]}], "Error: Input text cannot be empty."
|
57 |
|
58 |
# Tokenize input
|
59 |
inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
|
|
70 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
71 |
|
72 |
# Initialize outputs
|
73 |
+
html_plots = []
|
74 |
+
dataframes = []
|
75 |
|
76 |
# Visualization: Embeddings (t-SNE)
|
77 |
if visualize_option == "Embeddings":
|
|
|
89 |
plt.savefig(buf, format='png', bbox_inches='tight')
|
90 |
buf.seek(0)
|
91 |
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
92 |
+
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="t-SNE Plot" style="max-width:100%;"/>')
|
93 |
plt.close()
|
94 |
# Dataframe for coordinates
|
95 |
dataframe = pd.DataFrame({
|
|
|
97 |
"t-SNE X": reduced[:, 0],
|
98 |
"t-SNE Y": reduced[:, 1]
|
99 |
}).to_dict()
|
100 |
+
dataframes.append(dataframe)
|
101 |
except Exception as e:
|
102 |
logger.warning(f"t-SNE failed: {e}")
|
103 |
+
dataframes.append({"Error": [str(e)]})
|
104 |
+
html_plots.append("<p>Error: t-SNE computation failed.</p>")
|
105 |
else:
|
106 |
+
dataframes.append({"Error": ["Too few tokens for t-SNE."]})
|
107 |
+
html_plots.append("<p>Error: Too few tokens for t-SNE.</p>")
|
108 |
|
109 |
# Visualization: Attention Weights
|
110 |
elif visualize_option == "Attention":
|
|
|
119 |
plt.savefig(buf, format='png', bbox_inches='tight')
|
120 |
buf.seek(0)
|
121 |
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
122 |
+
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attention Heatmap" style="max-width:100%;"/>')
|
123 |
plt.close()
|
124 |
# Dataframe for attention weights
|
125 |
dataframe = pd.DataFrame(attn, index=tokens, columns=tokens).to_dict()
|
126 |
+
dataframes.append(dataframe)
|
127 |
else:
|
128 |
+
dataframes.append({"Error": ["No attention weights available."]})
|
129 |
+
html_plots.append("<p>Error: No attention weights available.</p>")
|
130 |
+
|
131 |
+
# Visualization: Activations
|
132 |
+
elif visualize_option == "Activations":
|
133 |
+
if layer_name in activations:
|
134 |
+
act = activations[layer_name]
|
135 |
+
if isinstance(act, tuple):
|
136 |
+
act = act[0]
|
137 |
+
act = act[0].detach().numpy() # [seq_len, hidden_size]
|
138 |
+
dataframe = pd.DataFrame(act, index=tokens).to_dict()
|
139 |
+
dataframes.append(dataframe)
|
140 |
+
# Plot mean activation per token
|
141 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
142 |
+
mean_act = np.mean(act, axis=1)
|
143 |
+
ax.bar(range(len(mean_act)), mean_act)
|
144 |
+
ax.set_xticks(range(len(mean_act)))
|
145 |
+
ax.set_xticklabels(tokens, rotation=45)
|
146 |
+
ax.set_title(f"Mean Activations in {layer_name}")
|
147 |
+
buf = io.BytesIO()
|
148 |
+
plt.savefig(buf, format='png', bbox_inches='tight')
|
149 |
+
buf.seek(0)
|
150 |
+
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
151 |
+
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Activations Plot" style="max-width:100%;"/>')
|
152 |
+
plt.close()
|
153 |
+
else:
|
154 |
+
dataframes.append({"Error": [f"Layer {layer_name} not found."]})
|
155 |
+
html_plots.append(f"<p>Error: Layer {layer_name} not found.</p>")
|
156 |
+
|
157 |
+
# Attribution: Integrated Gradients
|
158 |
+
def forward_func(inputs, attention_mask=None):
|
159 |
+
outputs = model(inputs, attention_mask=attention_mask)
|
160 |
+
return outputs.pooler_output[:, int(attribution_target)]
|
161 |
+
|
162 |
+
ig = IntegratedGradients(forward_func)
|
163 |
+
try:
|
164 |
+
attributions, _ = ig.attribute(
|
165 |
+
inputs=input_ids,
|
166 |
+
additional_forward_args=(attention_mask,),
|
167 |
+
target=int(attribution_target),
|
168 |
+
return_convergence_delta=True
|
169 |
+
)
|
170 |
+
attr = attributions[0].detach().numpy().sum(axis=1)
|
171 |
+
attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr}).to_dict()
|
172 |
+
dataframes.append(attr_df)
|
173 |
+
# Plot attributions
|
174 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
175 |
+
ax.bar(range(len(attr)), attr)
|
176 |
+
ax.set_xticks(range(len(attr)))
|
177 |
+
ax.set_xticklabels(tokens, rotation=45)
|
178 |
+
ax.set_title("Integrated Gradients Attribution")
|
179 |
+
buf = io.BytesIO()
|
180 |
+
plt.savefig(buf, format='png', bbox_inches='tight')
|
181 |
+
buf.seek(0)
|
182 |
+
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
183 |
+
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100%;"/>')
|
184 |
+
plt.close()
|
185 |
+
except Exception as e:
|
186 |
+
logger.warning(f"Integrated Gradients failed: {e}")
|
187 |
+
dataframes.append({"Error": [str(e)]})
|
188 |
+
html_plots.append("<p>Error: Attribution computation failed.</p>")
|
189 |
+
|
190 |
+
# Combine HTML plots
|
191 |
+
html_output = "<div>" + "".join(html_plots) + "</div>"
|
192 |
+
|
193 |
+
return html_output, dataframes, "Processing complete."
|
194 |
|
195 |
except Exception as e:
|
196 |
logger.error(f"Processing failed: {e}")
|
197 |
+
return f"<p>Error: {e}</p>", [{"Error": [str(e)]}], f"Error: {e}"
|
198 |
|
199 |
# Gradio Interface
|
200 |
def create_gradio_interface():
|
201 |
with gr.Blocks(title="Neural Network Visualization Demo") as demo:
|
202 |
gr.Markdown("# Neural Network Visualization Demo")
|
203 |
+
gr.Markdown("Analyze BERT's neural network paths. Enter text, select a layer, and choose a visualization.")
|
204 |
|
205 |
with gr.Row():
|
206 |
with gr.Column():
|
|
|
209 |
value="The quick brown fox jumps over the lazy dog.",
|
210 |
placeholder="Enter text here..."
|
211 |
)
|
212 |
+
layer_name = gr.Dropdown(
|
213 |
+
label="Select Layer",
|
214 |
+
choices=[str(name) for name, _ in model.named_modules() if 'layer' in name or 'embeddings' in name],
|
215 |
+
value="embeddings"
|
216 |
+
)
|
217 |
visualize_option = gr.Radio(
|
218 |
label="Visualization Type",
|
219 |
+
choices=["Embeddings", "Attention", "Activations"],
|
220 |
value="Embeddings"
|
221 |
)
|
222 |
+
attribution_target = gr.Slider(
|
223 |
+
label="Attribution Target Class (0 or 1)",
|
224 |
+
minimum=0,
|
225 |
+
maximum=1,
|
226 |
+
step=1,
|
227 |
+
value=0
|
228 |
+
)
|
229 |
submit_btn = gr.Button("Analyze")
|
230 |
|
231 |
with gr.Column():
|
232 |
+
plot_output = gr.HTML(label="Visualizations")
|
233 |
+
dataframe_output = gr.Dataframe(label="Data Outputs")
|
234 |
text_output = gr.Textbox(label="Messages")
|
235 |
|
236 |
submit_btn.click(
|
237 |
fn=process_input,
|
238 |
+
inputs=[input_text, layer_name, visualize_option, attribution_target],
|
239 |
outputs=[plot_output, dataframe_output, text_output]
|
240 |
)
|
241 |
|