Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -34,6 +34,16 @@ for name, layer in model.named_modules():
|
|
34 |
if 'layer' in name or 'embeddings' in name:
|
35 |
layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
def process_input(input_text, layer_name, visualize_option, attribution_target=0):
|
38 |
"""
|
39 |
Process input text, compute embeddings, activations, attention, and attribution.
|
@@ -44,7 +54,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
44 |
- attribution_target: Target class for attribution (0 or 1)
|
45 |
Returns:
|
46 |
- HTML string with base64-encoded image(s)
|
47 |
-
- List of dataframe dictionaries
|
48 |
- Status message
|
49 |
"""
|
50 |
global activations
|
@@ -94,10 +104,11 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
94 |
# Dataframe for coordinates
|
95 |
dataframe = pd.DataFrame({
|
96 |
"Token": tokens,
|
97 |
-
"t-
|
98 |
-
"t-
|
99 |
-
})
|
100 |
-
|
|
|
101 |
except Exception as e:
|
102 |
logger.warning(f"t-SNE failed: {e}")
|
103 |
dataframes.append({"Error": [str(e)]})
|
@@ -122,8 +133,8 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
122 |
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attention Heatmap" style="max-width:100%;"/>')
|
123 |
plt.close()
|
124 |
# Dataframe for attention weights
|
125 |
-
dataframe = pd.DataFrame(attn, index=tokens, columns=tokens)
|
126 |
-
dataframes.append(dataframe)
|
127 |
else:
|
128 |
dataframes.append({"Error": ["No attention weights available."]})
|
129 |
html_plots.append("<p>Error: No attention weights available.</p>")
|
@@ -135,8 +146,8 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
135 |
if isinstance(act, tuple):
|
136 |
act = act[0]
|
137 |
act = act[0].detach().numpy() # [seq_len, hidden_size]
|
138 |
-
dataframe = pd.DataFrame(act, index=tokens
|
139 |
-
dataframes.append(dataframe)
|
140 |
# Plot mean activation per token
|
141 |
fig, ax = plt.subplots(figsize=(8, 6))
|
142 |
mean_act = np.mean(act, axis=1)
|
@@ -168,8 +179,9 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
168 |
return_convergence_delta=True
|
169 |
)
|
170 |
attr = attributions[0].detach().numpy().sum(axis=1)
|
171 |
-
attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
|
172 |
-
|
|
|
173 |
# Plot attributions
|
174 |
fig, ax = plt.subplots(figsize=(8, 6))
|
175 |
ax.bar(range(len(attr)), attr)
|
@@ -180,7 +192,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
180 |
plt.savefig(buf, format='png', bbox_inches='tight')
|
181 |
buf.seek(0)
|
182 |
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
183 |
-
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100
|
184 |
plt.close()
|
185 |
except Exception as e:
|
186 |
logger.warning(f"Integrated Gradients failed: {e}")
|
|
|
34 |
if 'layer' in name or 'embeddings' in name:
|
35 |
layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
|
36 |
|
37 |
+
def convert_dict_keys_to_str(d):
    """Recursively convert all dictionary keys to strings.

    Walks nested containers so the result can be handed to serializers
    that only accept string keys and plain Python values.

    Args:
        d: Any object. Dicts, lists, tuples, NumPy arrays and NumPy
            scalars are converted; every other value is returned as-is.

    Returns:
        A new structure in which:
        - every dict key has been passed through ``str()`` (if two keys
          stringify to the same text, the later one wins),
        - lists and tuples keep their type but have converted elements,
        - NumPy arrays become nested Python lists,
        - NumPy scalars become the equivalent Python scalar.
    """
    if isinstance(d, dict):
        return {str(k): convert_dict_keys_to_str(v) for k, v in d.items()}
    if isinstance(d, list):
        return [convert_dict_keys_to_str(item) for item in d]
    if isinstance(d, tuple):
        # Dicts nested inside tuples must have their keys converted as well.
        return tuple(convert_dict_keys_to_str(item) for item in d)
    if isinstance(d, np.ndarray):
        values = d.tolist()  # Convert numpy arrays to lists
        # Only object arrays can hold dicts; numeric arrays need no extra walk.
        return convert_dict_keys_to_str(values) if d.dtype == object else values
    if isinstance(d, np.generic):
        # NumPy scalars (e.g. float32) are converted to plain Python scalars,
        # consistent with how arrays are handled above.
        return d.item()
    return d
|
46 |
+
|
47 |
def process_input(input_text, layer_name, visualize_option, attribution_target=0):
|
48 |
"""
|
49 |
Process input text, compute embeddings, activations, attention, and attribution.
|
|
|
54 |
- attribution_target: Target class for attribution (0 or 1)
|
55 |
Returns:
|
56 |
- HTML string with base64-encoded image(s)
|
57 |
+
- List of dataframe dictionaries with string keys
|
58 |
- Status message
|
59 |
"""
|
60 |
global activations
|
|
|
104 |
# Dataframe for coordinates
|
105 |
dataframe = pd.DataFrame({
|
106 |
"Token": tokens,
|
107 |
+
"t-SNE_X": reduced[:, 0],
|
108 |
+
"t-SNE_Y": reduced[:, 1]
|
109 |
+
})
|
110 |
+
dataframe.index = [f"idx_{i}" for i in range(len(dataframe))] # String indices
|
111 |
+
dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
|
112 |
except Exception as e:
|
113 |
logger.warning(f"t-SNE failed: {e}")
|
114 |
dataframes.append({"Error": [str(e)]})
|
|
|
133 |
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attention Heatmap" style="max-width:100%;"/>')
|
134 |
plt.close()
|
135 |
# Dataframe for attention weights
|
136 |
+
dataframe = pd.DataFrame(attn, index=tokens, columns=[f"token_{i}" for i in range(len(tokens))])
|
137 |
+
dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
|
138 |
else:
|
139 |
dataframes.append({"Error": ["No attention weights available."]})
|
140 |
html_plots.append("<p>Error: No attention weights available.</p>")
|
|
|
146 |
if isinstance(act, tuple):
|
147 |
act = act[0]
|
148 |
act = act[0].detach().numpy() # [seq_len, hidden_size]
|
149 |
+
dataframe = pd.DataFrame(act, index=tokens, columns=[f"dim_{i}" for i in range(act.shape[1])])
|
150 |
+
dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
|
151 |
# Plot mean activation per token
|
152 |
fig, ax = plt.subplots(figsize=(8, 6))
|
153 |
mean_act = np.mean(act, axis=1)
|
|
|
179 |
return_convergence_delta=True
|
180 |
)
|
181 |
attr = attributions[0].detach().numpy().sum(axis=1)
|
182 |
+
attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
|
183 |
+
attr_df.index = [f"idx_{i}" for i in range(len(attr_df))] # String indices
|
184 |
+
dataframes.append(convert_dict_keys_to_str(attr_df.to_dict()))
|
185 |
# Plot attributions
|
186 |
fig, ax = plt.subplots(figsize=(8, 6))
|
187 |
ax.bar(range(len(attr)), attr)
|
|
|
192 |
plt.savefig(buf, format='png', bbox_inches='tight')
|
193 |
buf.seek(0)
|
194 |
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
195 |
+
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100%;"/>')
|
196 |
plt.close()
|
197 |
except Exception as e:
|
198 |
logger.warning(f"Integrated Gradients failed: {e}")
|