Spaces:
Sleeping
Sleeping
Calculate pct of total attention
Browse files
hexviz/attention.py
CHANGED
@@ -299,6 +299,7 @@ def get_attention_pairs(
|
|
299 |
top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)
|
300 |
|
301 |
for res, attn_sum in zip(top_n_indexes, top_n_values):
|
302 |
-
|
|
|
303 |
|
304 |
return attention_pairs, top_residues
|
|
|
299 |
top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)
|
300 |
|
301 |
for res, attn_sum in zip(top_n_indexes, top_n_values):
|
302 |
+
fraction_of_total_attention = attn_sum.item() / len(sequence)
|
303 |
+
top_residues.append((fraction_of_total_attention, chain_ids[i], res.item()))
|
304 |
|
305 |
return attention_pairs, top_residues
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -264,15 +264,19 @@ Pick a PDB ID, layer and head to visualize attention from the selected protein l
|
|
264 |
|
265 |
chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
|
266 |
data = []
|
267 |
-
for
|
268 |
try:
|
269 |
res = chain_dict[chain][resi]
|
270 |
except KeyError:
|
271 |
continue
|
272 |
-
|
|
|
273 |
data.append(el)
|
274 |
|
275 |
-
df = pd.DataFrame(data, columns=["
|
|
|
|
|
|
|
276 |
st.markdown(
|
277 |
f"The {n_highest_resis} residues (per chain) with the highest attention to them are labeled in the visualization and listed here:"
|
278 |
)
|
|
|
264 |
|
265 |
chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
|
266 |
data = []
|
267 |
+
for fraction_of_total_attention, chain, resi in top_residues:
|
268 |
try:
|
269 |
res = chain_dict[chain][resi]
|
270 |
except KeyError:
|
271 |
continue
|
272 |
+
pct_of_total_attention = round(fraction_of_total_attention * 100, 3)
|
273 |
+
el = (pct_of_total_attention, f"{res.resname:3}{res.id[1]}({chain})")
|
274 |
data.append(el)
|
275 |
|
276 |
+
df = pd.DataFrame(data, columns=["% of total attention", "Residue"])
|
277 |
+
df = df.style.format(
|
278 |
+
{"% of total attention": "{:.3f}"} # Set 3 decimal places for "% of total attention"
|
279 |
+
)
|
280 |
st.markdown(
|
281 |
f"The {n_highest_resis} residues (per chain) with the highest attention to them are labeled in the visualization and listed here:"
|
282 |
)
|