broadfield-dev committed · verified
Commit 5ea20f7 · 1 Parent(s): 8d87543

Update app.py

Files changed (1)
app.py +24 -12
app.py CHANGED
@@ -34,6 +34,16 @@ for name, layer in model.named_modules():
     if 'layer' in name or 'embeddings' in name:
         layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
 
+def convert_dict_keys_to_str(d):
+    """Recursively convert all dictionary keys to strings."""
+    if isinstance(d, dict):
+        return {str(k): convert_dict_keys_to_str(v) for k, v in d.items()}
+    elif isinstance(d, list):
+        return [convert_dict_keys_to_str(item) for item in d]
+    elif isinstance(d, np.ndarray):
+        return d.tolist()  # Convert numpy arrays to lists
+    return d
+
 def process_input(input_text, layer_name, visualize_option, attribution_target=0):
     """
     Process input text, compute embeddings, activations, attention, and attribution.
@@ -44,7 +54,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
     - attribution_target: Target class for attribution (0 or 1)
     Returns:
     - HTML string with base64-encoded image(s)
-    - List of dataframe dictionaries
+    - List of dataframe dictionaries with string keys
     - Status message
     """
     global activations
@@ -94,10 +104,11 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
             # Dataframe for coordinates
             dataframe = pd.DataFrame({
                 "Token": tokens,
-                "t-SNE X": reduced[:, 0],
-                "t-SNE Y": reduced[:, 1]
-            }).to_dict()
-            dataframes.append(dataframe)
+                "t-SNE_X": reduced[:, 0],
+                "t-SNE_Y": reduced[:, 1]
+            })
+            dataframe.index = [f"idx_{i}" for i in range(len(dataframe))]  # String indices
+            dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
         except Exception as e:
             logger.warning(f"t-SNE failed: {e}")
             dataframes.append({"Error": [str(e)]})
@@ -122,8 +133,8 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
             html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attention Heatmap" style="max-width:100%;"/>')
             plt.close()
             # Dataframe for attention weights
-            dataframe = pd.DataFrame(attn, index=tokens, columns=tokens).to_dict()
-            dataframes.append(dataframe)
+            dataframe = pd.DataFrame(attn, index=tokens, columns=[f"token_{i}" for i in range(len(tokens))])
+            dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
         else:
             dataframes.append({"Error": ["No attention weights available."]})
             html_plots.append("<p>Error: No attention weights available.</p>")
@@ -135,8 +146,8 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
         if isinstance(act, tuple):
             act = act[0]
         act = act[0].detach().numpy()  # [seq_len, hidden_size]
-        dataframe = pd.DataFrame(act, index=tokens).to_dict()
-        dataframes.append(dataframe)
+        dataframe = pd.DataFrame(act, index=tokens, columns=[f"dim_{i}" for i in range(act.shape[1])])
+        dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
         # Plot mean activation per token
         fig, ax = plt.subplots(figsize=(8, 6))
         mean_act = np.mean(act, axis=1)
@@ -168,8 +179,9 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
                 return_convergence_delta=True
             )
             attr = attributions[0].detach().numpy().sum(axis=1)
-            attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr}).to_dict()
-            dataframes.append(attr_df)
+            attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
+            attr_df.index = [f"idx_{i}" for i in range(len(attr_df))]  # String indices
+            dataframes.append(convert_dict_keys_to_str(attr_df.to_dict()))
             # Plot attributions
             fig, ax = plt.subplots(figsize=(8, 6))
             ax.bar(range(len(attr)), attr)
@@ -180,7 +192,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
             plt.savefig(buf, format='png', bbox_inches='tight')
             buf.seek(0)
             img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
-            html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100%;"/>')
+            html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100%"/>')
             plt.close()
         except Exception as e:
             logger.warning(f"Integrated Gradients failed: {e}")
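Note on the fix: DataFrame.to_dict() can hand back numpy scalar keys, which json.dumps rejects even though it happily coerces plain int keys to strings. Assuming these dicts are serialized to JSON for the UI downstream (e.g., a Gradio JSON output — that consumer is not visible in this diff), the new helper is what makes them serializable; the idx_{i} string row indices serve the same purpose for the inner keys. A minimal, self-contained sketch of the failure and the fix, reusing the helper added above:

import json
import numpy as np

def convert_dict_keys_to_str(d):
    """Recursively convert all dictionary keys to strings (the helper added in this commit)."""
    if isinstance(d, dict):
        return {str(k): convert_dict_keys_to_str(v) for k, v in d.items()}
    elif isinstance(d, list):
        return [convert_dict_keys_to_str(item) for item in d]
    elif isinstance(d, np.ndarray):
        return d.tolist()  # numpy arrays become plain lists
    return d

json.dumps({0: 1.0})                # OK: plain int keys are coerced to "0"
try:
    json.dumps({np.int64(0): 1.0})  # numpy scalar key
except TypeError as e:
    print(e)                        # "keys must be str, int, float, bool or None, not int64"

print(json.dumps(convert_dict_keys_to_str({np.int64(0): 1.0})))  # {"0": 1.0}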