Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -165,23 +165,27 @@ def process_input(input_text, layer_name, visualize_option, attribution_target=0
|
|
165 |
dataframes.append({"Error": [f"Layer {layer_name} not found."]})
|
166 |
html_plots.append(f"<p>Error: Layer {layer_name} not found.</p>")
|
167 |
|
168 |
-
# Attribution: Integrated Gradients
|
169 |
-
def
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
172 |
return outputs.pooler_output[:, int(attribution_target)]
|
173 |
|
174 |
ig = IntegratedGradients(forward_func)
|
175 |
try:
|
176 |
-
#
|
177 |
-
|
178 |
attributions, _ = ig.attribute(
|
179 |
-
inputs=
|
180 |
additional_forward_args=(attention_mask,),
|
181 |
target=int(attribution_target),
|
182 |
return_convergence_delta=True
|
183 |
)
|
184 |
-
attr = attributions[0].detach().numpy().sum(axis=1)
|
185 |
attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
|
186 |
attr_df.index = [f"idx_{i}" for i in range(len(attr_df))] # String indices
|
187 |
dataframes.append(convert_dict_keys_to_str(attr_df.to_dict()))
|
@@ -260,7 +264,7 @@ def create_gradio_interface():
|
|
260 |
if __name__ == "__main__":
|
261 |
try:
|
262 |
demo = create_gradio_interface()
|
263 |
-
demo.launch(server_name="0.0.0.0", server_port=
|
264 |
except Exception as e:
|
265 |
logger.error(f"Failed to launch Gradio demo: {e}")
|
266 |
print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")
|
|
|
165 |
dataframes.append({"Error": [f"Layer {layer_name} not found."]})
|
166 |
html_plots.append(f"<p>Error: Layer {layer_name} not found.</p>")
|
167 |
|
168 |
+
# Attribution: Integrated Gradients on embeddings
|
169 |
+
def get_embeddings(inputs, attention_mask=None):
|
170 |
+
with torch.no_grad():
|
171 |
+
embeddings = model.bert.embeddings(inputs) # Get float embeddings
|
172 |
+
return embeddings
|
173 |
+
|
174 |
+
def forward_func(embeddings, attention_mask=None):
|
175 |
+
outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask)
|
176 |
return outputs.pooler_output[:, int(attribution_target)]
|
177 |
|
178 |
ig = IntegratedGradients(forward_func)
|
179 |
try:
|
180 |
+
# Get embeddings for input_ids
|
181 |
+
embeddings = get_embeddings(input_ids, attention_mask).requires_grad_(True)
|
182 |
attributions, _ = ig.attribute(
|
183 |
+
inputs=embeddings,
|
184 |
additional_forward_args=(attention_mask,),
|
185 |
target=int(attribution_target),
|
186 |
return_convergence_delta=True
|
187 |
)
|
188 |
+
attr = attributions[0].detach().numpy().sum(axis=1) # Sum over hidden size
|
189 |
attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
|
190 |
attr_df.index = [f"idx_{i}" for i in range(len(attr_df))] # String indices
|
191 |
dataframes.append(convert_dict_keys_to_str(attr_df.to_dict()))
|
|
|
264 |
if __name__ == "__main__":
|
265 |
try:
|
266 |
demo = create_gradio_interface()
|
267 |
+
demo.launch(server_name="0.0.0.0", server_port=7861, share=False)
|
268 |
except Exception as e:
|
269 |
logger.error(f"Failed to launch Gradio demo: {e}")
|
270 |
print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")
|