Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,10 +15,10 @@ import logging
|
|
| 15 |
logging.basicConfig(level=logging.INFO)
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
-
# Initialize BERT model and tokenizer
|
| 19 |
try:
|
| 20 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 21 |
-
model = BertModel.from_pretrained('bert-base-uncased')
|
| 22 |
model.eval()
|
| 23 |
except Exception as e:
|
| 24 |
logger.error(f"Failed to load BERT model: {e}")
|
|
@@ -67,7 +67,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
| 67 |
|
| 68 |
# Tokenize input
|
| 69 |
inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
| 70 |
-
input_ids = inputs['input_ids']
|
| 71 |
attention_mask = inputs['attention_mask']
|
| 72 |
|
| 73 |
# Forward pass
|
|
@@ -167,11 +167,14 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
| 167 |
|
| 168 |
# Attribution: Integrated Gradients
|
| 169 |
def forward_func(inputs, attention_mask=None):
|
|
|
|
| 170 |
outputs = model(inputs, attention_mask=attention_mask)
|
| 171 |
return outputs.pooler_output[:, int(attribution_target)]
|
| 172 |
|
| 173 |
ig = IntegratedGradients(forward_func)
|
| 174 |
try:
|
|
|
|
|
|
|
| 175 |
attributions, _ = ig.attribute(
|
| 176 |
inputs=input_ids,
|
| 177 |
additional_forward_args=(attention_mask,),
|
|
@@ -192,7 +195,7 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
| 192 |
plt.savefig(buf, format='png', bbox_inches='tight')
|
| 193 |
buf.seek(0)
|
| 194 |
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
| 195 |
-
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100
|
| 196 |
plt.close()
|
| 197 |
except Exception as e:
|
| 198 |
logger.warning(f"Integrated Gradients failed: {e}")
|
|
@@ -257,7 +260,7 @@ def create_gradio_interface():
|
|
| 257 |
if __name__ == "__main__":
|
| 258 |
try:
|
| 259 |
demo = create_gradio_interface()
|
| 260 |
-
demo.launch(server_name="0.0.0.0", server_port=
|
| 261 |
except Exception as e:
|
| 262 |
logger.error(f"Failed to launch Gradio demo: {e}")
|
| 263 |
print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")
|
|
|
|
| 15 |
logging.basicConfig(level=logging.INFO)
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
+
# Initialize BERT model and tokenizer with eager attention
|
| 19 |
try:
|
| 20 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 21 |
+
model = BertModel.from_pretrained('bert-base-uncased', attn_implementation="eager")
|
| 22 |
model.eval()
|
| 23 |
except Exception as e:
|
| 24 |
logger.error(f"Failed to load BERT model: {e}")
|
|
|
|
| 67 |
|
| 68 |
# Tokenize input
|
| 69 |
inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
| 70 |
+
input_ids = inputs['input_ids'].to(dtype=torch.long) # Ensure LongTensor
|
| 71 |
attention_mask = inputs['attention_mask']
|
| 72 |
|
| 73 |
# Forward pass
|
|
|
|
| 167 |
|
| 168 |
# Attribution: Integrated Gradients
|
| 169 |
def forward_func(inputs, attention_mask=None):
|
| 170 |
+
inputs = inputs.to(dtype=torch.long) # Ensure LongTensor
|
| 171 |
outputs = model(inputs, attention_mask=attention_mask)
|
| 172 |
return outputs.pooler_output[:, int(attribution_target)]
|
| 173 |
|
| 174 |
ig = IntegratedGradients(forward_func)
|
| 175 |
try:
|
| 176 |
+
# Ensure input_ids is LongTensor and requires grad
|
| 177 |
+
input_ids = input_ids.to(dtype=torch.long).requires_grad_(True)
|
| 178 |
attributions, _ = ig.attribute(
|
| 179 |
inputs=input_ids,
|
| 180 |
additional_forward_args=(attention_mask,),
|
|
|
|
| 195 |
plt.savefig(buf, format='png', bbox_inches='tight')
|
| 196 |
buf.seek(0)
|
| 197 |
img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
| 198 |
+
html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100%;"/>')
|
| 199 |
plt.close()
|
| 200 |
except Exception as e:
|
| 201 |
logger.warning(f"Integrated Gradients failed: {e}")
|
|
|
|
| 260 |
if __name__ == "__main__":
|
| 261 |
try:
|
| 262 |
demo = create_gradio_interface()
|
| 263 |
+
demo.launch(server_name="0.0.0.0", server_port=7861, share=False)
|
| 264 |
except Exception as e:
|
| 265 |
logger.error(f"Failed to launch Gradio demo: {e}")
|
| 266 |
print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")
|