broadfield-dev committed
Commit 4152a57 · verified · 1 Parent(s): 5ea20f7

Update app.py

Files changed (1): app.py (+8 -5)
app.py CHANGED
@@ -15,10 +15,10 @@ import logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Initialize BERT model and tokenizer
+# Initialize BERT model and tokenizer with eager attention
 try:
     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-    model = BertModel.from_pretrained('bert-base-uncased')
+    model = BertModel.from_pretrained('bert-base-uncased', attn_implementation="eager")
     model.eval()
 except Exception as e:
     logger.error(f"Failed to load BERT model: {e}")
@@ -67,7 +67,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
 
     # Tokenize input
     inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
-    input_ids = inputs['input_ids']
+    input_ids = inputs['input_ids'].to(dtype=torch.long)  # Ensure LongTensor
     attention_mask = inputs['attention_mask']
 
     # Forward pass
@@ -167,11 +167,14 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
 
     # Attribution: Integrated Gradients
     def forward_func(inputs, attention_mask=None):
+        inputs = inputs.to(dtype=torch.long)  # Ensure LongTensor
         outputs = model(inputs, attention_mask=attention_mask)
         return outputs.pooler_output[:, int(attribution_target)]
 
     ig = IntegratedGradients(forward_func)
     try:
+        # Ensure input_ids is LongTensor and requires grad
+        input_ids = input_ids.to(dtype=torch.long).requires_grad_(True)
         attributions, _ = ig.attribute(
             inputs=input_ids,
             additional_forward_args=(attention_mask,),
@@ -192,7 +195,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
         plt.savefig(buf, format='png', bbox_inches='tight')
         buf.seek(0)
         img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
-        html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100?%"/>')
+        html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100%;"/>')
         plt.close()
     except Exception as e:
         logger.warning(f"Integrated Gradients failed: {e}")
@@ -257,7 +260,7 @@ def create_gradio_interface():
 if __name__ == "__main__":
     try:
         demo = create_gradio_interface()
-        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
+        demo.launch(server_name="0.0.0.0", server_port=7861, share=False)
     except Exception as e:
         logger.error(f"Failed to launch Gradio demo: {e}")
         print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")
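Note on the attn_implementation="eager" change: the SDPA attention path in recent Transformers releases does not materialize attention probabilities, so requesting output_attentions=True makes the model fall back to the manual implementation with a warning. Loading BERT with eager attention up front keeps the per-layer attention weights available for visualization. A minimal, self-contained sketch follows; the model name matches the diff, while the example sentence and the output_attentions=True call are illustrative assumptions, not code from app.py:

# Sketch only, not part of the commit: load BERT with eager attention so that
# output_attentions=True returns per-layer attention weights for visualization.
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', attn_implementation="eager")
model.eval()

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# One (batch, num_heads, seq_len, seq_len) tensor per layer.
print(len(outputs.attentions), outputs.attentions[0].shape)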
 
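Note on the LongTensor changes around Integrated Gradients: PyTorch only allows floating-point (and complex) tensors to require gradients, so calling requires_grad_(True) on integer token ids can still raise at runtime. The pattern commonly used with Captum for token-ID inputs is LayerIntegratedGradients over the embedding layer, which leaves input_ids as a LongTensor and integrates over the embedding output instead. The sketch below is a reference for that alternative, not the code in this commit; the function and variable names mirror the diff, and the all-[PAD] baseline and attribution target of 0 are assumptions.

# Reference sketch (not this commit's code): Captum's LayerIntegratedGradients over
# the embedding layer, which keeps input_ids as a LongTensor.
import torch
from transformers import BertModel, BertTokenizer
from captum.attr import LayerIntegratedGradients

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', attn_implementation="eager")
model.eval()

def forward_func(input_ids, attention_mask=None):
    outputs = model(input_ids, attention_mask=attention_mask)
    return outputs.pooler_output[:, 0]  # attribution target 0, as in the diff

inputs = tokenizer("The quick brown fox jumps over the lazy dog.",
                   return_tensors='pt', padding=True, truncation=True, max_length=512)
input_ids = inputs['input_ids']          # stays torch.long, no requires_grad_ needed
attention_mask = inputs['attention_mask']
baseline_ids = torch.full_like(input_ids, tokenizer.pad_token_id)  # assumed baseline

lig = LayerIntegratedGradients(forward_func, model.embeddings)
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_ids,
    additional_forward_args=(attention_mask,),
    return_convergence_delta=True,
)
token_scores = attributions.sum(dim=-1).squeeze(0)  # one attribution score per token

With this layout the gradient flows through the embedding output rather than the integer ids, which is the effect the requires_grad workaround in the diff is aiming for.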