broadfield-dev commited on
Commit
623d954
·
verified ·
1 Parent(s): 4d2190d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -165,23 +165,27 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
165
  dataframes.append({"Error": [f"Layer {layer_name} not found."]})
166
  html_plots.append(f"<p>Error: Layer {layer_name} not found.</p>")
167
 
168
- # Attribution: Integrated Gradients
169
- def forward_func(inputs, attention_mask=None):
170
- inputs = inputs.to(dtype=torch.long) # Ensure LongTensor
171
- outputs = model(inputs, attention_mask=attention_mask)
 
 
 
 
172
  return outputs.pooler_output[:, int(attribution_target)]
173
 
174
  ig = IntegratedGradients(forward_func)
175
  try:
176
- # Ensure input_ids is LongTensor and requires grad
177
- input_ids = input_ids.to(dtype=torch.long).requires_grad_(True)
178
  attributions, _ = ig.attribute(
179
- inputs=input_ids,
180
  additional_forward_args=(attention_mask,),
181
  target=int(attribution_target),
182
  return_convergence_delta=True
183
  )
184
- attr = attributions[0].detach().numpy().sum(axis=1)
185
  attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
186
  attr_df.index = [f"idx_{i}" for i in range(len(attr_df))] # String indices
187
  dataframes.append(convert_dict_keys_to_str(attr_df.to_dict()))
@@ -260,7 +264,7 @@ def create_gradio_interface():
260
  if __name__ == "__main__":
261
  try:
262
  demo = create_gradio_interface()
263
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
264
  except Exception as e:
265
  logger.error(f"Failed to launch Gradio demo: {e}")
266
  print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")
 
165
  dataframes.append({"Error": [f"Layer {layer_name} not found."]})
166
  html_plots.append(f"<p>Error: Layer {layer_name} not found.</p>")
167
 
168
+ # Attribution: Integrated Gradients on embeddings
169
+ def get_embeddings(inputs, attention_mask=None):
170
+ with torch.no_grad():
171
+ embeddings = model.bert.embeddings(inputs) # Get float embeddings
172
+ return embeddings
173
+
174
+ def forward_func(embeddings, attention_mask=None):
175
+ outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask)
176
  return outputs.pooler_output[:, int(attribution_target)]
177
 
178
  ig = IntegratedGradients(forward_func)
179
  try:
180
+ # Get embeddings for input_ids
181
+ embeddings = get_embeddings(input_ids, attention_mask).requires_grad_(True)
182
  attributions, _ = ig.attribute(
183
+ inputs=embeddings,
184
  additional_forward_args=(attention_mask,),
185
  target=int(attribution_target),
186
  return_convergence_delta=True
187
  )
188
+ attr = attributions[0].detach().numpy().sum(axis=1) # Sum over hidden size
189
  attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
190
  attr_df.index = [f"idx_{i}" for i in range(len(attr_df))] # String indices
191
  dataframes.append(convert_dict_keys_to_str(attr_df.to_dict()))
 
264
  if __name__ == "__main__":
265
  try:
266
  demo = create_gradio_interface()
267
+ demo.launch(server_name="0.0.0.0", server_port=7861, share=False)
268
  except Exception as e:
269
  logger.error(f"Failed to launch Gradio demo: {e}")
270
  print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")