Spaces:
Sleeping
Sleeping
Calculate top residue using attention to residue
Browse files
Using attention from the residues as well does not make much sense,
as the attention from each residue always sums to 1, so
it would only add 1 to every residue's total attention.
hexviz/attention.py
CHANGED
@@ -292,11 +292,13 @@ def get_attention_pairs(
|
|
292 |
residue_attention[res - ec_tag_length] = (
|
293 |
residue_attention.get(res - ec_tag_length, 0) + attn_value
|
294 |
)
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
|
|
|
|
301 |
|
302 |
return attention_pairs, top_residues
|
|
|
292 |
residue_attention[res - ec_tag_length] = (
|
293 |
residue_attention.get(res - ec_tag_length, 0) + attn_value
|
294 |
)
|
295 |
+
if not ec_number:
|
296 |
+
attention_into_res = attention[head, layer].sum(dim=0)
|
297 |
+
else:
|
298 |
+
attention_into_res = attention[head, layer, ec_tag_length:, ec_tag_length:].sum(dim=0)
|
299 |
+
top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)
|
300 |
+
|
301 |
+
for res, attn_sum in zip(top_n_indexes, top_n_values):
|
302 |
+
top_residues.append((attn_sum.item(), chain_ids[i], res.item()))
|
303 |
|
304 |
return attention_pairs, top_residues
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -71,7 +71,6 @@ n_highest_resis = st.sidebar.number_input(
|
|
71 |
)
|
72 |
label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
|
73 |
sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
|
74 |
-
# TODO add avg or max attention as params
|
75 |
|
76 |
|
77 |
with st.sidebar.expander("Label residues manually"):
|
@@ -238,7 +237,7 @@ def get_3dview(pdb):
|
|
238 |
)
|
239 |
|
240 |
if label_highest:
|
241 |
-
for _,
|
242 |
one_indexed_res = res + 1
|
243 |
xyzview.addResLabels(
|
244 |
{"chain": chain, "resi": one_indexed_res},
|
@@ -265,7 +264,7 @@ Pick a PDB ID, layer and head to visualize attention from the selected protein l
|
|
265 |
|
266 |
chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
|
267 |
data = []
|
268 |
-
for att_weight,
|
269 |
try:
|
270 |
res = chain_dict[chain][resi]
|
271 |
except KeyError:
|
@@ -273,9 +272,9 @@ for att_weight, _, chain, resi in top_residues:
|
|
273 |
el = (att_weight, f"{res.resname:3}{res.id[1]}({chain})")
|
274 |
data.append(el)
|
275 |
|
276 |
-
df = pd.DataFrame(data, columns=["Total attention
|
277 |
st.markdown(
|
278 |
-
f"The {n_highest_resis} residues (per chain) with the highest attention
|
279 |
)
|
280 |
st.table(df)
|
281 |
|
|
|
71 |
)
|
72 |
label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
|
73 |
sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
|
|
|
74 |
|
75 |
|
76 |
with st.sidebar.expander("Label residues manually"):
|
|
|
237 |
)
|
238 |
|
239 |
if label_highest:
|
240 |
+
for _, chain, res in top_residues:
|
241 |
one_indexed_res = res + 1
|
242 |
xyzview.addResLabels(
|
243 |
{"chain": chain, "resi": one_indexed_res},
|
|
|
264 |
|
265 |
chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
|
266 |
data = []
|
267 |
+
for att_weight, chain, resi in top_residues:
|
268 |
try:
|
269 |
res = chain_dict[chain][resi]
|
270 |
except KeyError:
|
|
|
272 |
el = (att_weight, f"{res.resname:3}{res.id[1]}({chain})")
|
273 |
data.append(el)
|
274 |
|
275 |
+
df = pd.DataFrame(data, columns=["Total attention to", "Residue"])
|
276 |
st.markdown(
|
277 |
+
f"The {n_highest_resis} residues (per chain) with the highest attention to them are labeled in the visualization and listed here:"
|
278 |
)
|
279 |
st.table(df)
|
280 |
|