aksell committed
Commit 725c921 · 1 Parent(s): 0a6b613

Calculate top residue using attention to residue


Using attention from the residues as well does not make much sense:
the attention from each residue always sums to 1, so including it only
adds a constant 1 to every residue's total attention.
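
A minimal sketch of that reasoning (standalone PyTorch; the names below are illustrative and not taken from hexviz): every row of a softmax attention map sums to 1, so adding the attention from each residue to its score shifts every total by a constant and never changes which residues rank highest.

import torch

# A[i, j] = attention from residue i to residue j; the softmax over j
# makes every row sum to exactly 1.
seq_len = 6
A = torch.softmax(torch.randn(seq_len, seq_len), dim=-1)

attention_from = A.sum(dim=1)  # all ones -- carries no ranking information
attention_into = A.sum(dim=0)  # varies per residue -- the useful signal

# Adding the constant "from" totals shifts every score by ~1, so the
# top-k ordering is the same either way.
print(torch.topk(attention_into, k=3).indices)
print(torch.topk(attention_into + attention_from, k=3).indices)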

hexviz/attention.py CHANGED
@@ -292,11 +292,13 @@ def get_attention_pairs(
                 residue_attention[res - ec_tag_length] = (
                     residue_attention.get(res - ec_tag_length, 0) + attn_value
                 )
-
-    top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
-
-    for res, attn_sum in top_n_residues:
-        coord = chain[res]["CA"].coord.tolist()
-        top_residues.append((attn_sum, coord, chain_ids[i], res))
+        if not ec_number:
+            attention_into_res = attention[head, layer].sum(dim=0)
+        else:
+            attention_into_res = attention[head, layer, ec_tag_length:, ec_tag_length:].sum(dim=0)
+        top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)
+
+        for res, attn_sum in zip(top_n_indexes, top_n_values):
+            top_residues.append((attn_sum.item(), chain_ids[i], res.item()))
 
     return attention_pairs, top_residues
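
As a hedged aside on what the new lines compute (the shapes and the ec_tag_length value below are made up for illustration): when an EC-number tag is prepended to the sequence, slicing ec_tag_length: off both axes drops the tag tokens before summing attention into each residue, and torch.topk takes over the job of the old sorted(...)[:top_n] over a Python dict.

import torch

ec_tag_length, n_res = 4, 10  # hypothetical tag and protein lengths
length = ec_tag_length + n_res
attn = torch.softmax(torch.randn(length, length), dim=-1)  # one head, one layer

# Drop the tag tokens from both axes so tag-to-residue attention is ignored,
# then sum over the "from" axis (dim=0) to get attention into each residue.
attention_into_res = attn[ec_tag_length:, ec_tag_length:].sum(dim=0)

# torch.topk returns the highest sums and their zero-based residue indexes
# in a single call.
top_n_values, top_n_indexes = torch.topk(attention_into_res, k=3)
print(top_n_values, top_n_indexes)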
hexviz/🧬Attention_Visualization.py CHANGED
@@ -71,7 +71,6 @@ n_highest_resis = st.sidebar.number_input(
 )
 label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
 sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
-# TODO add avg or max attention as params
 
 
 with st.sidebar.expander("Label residues manually"):
@@ -238,7 +237,7 @@ def get_3dview(pdb):
     )
 
     if label_highest:
-        for _, _, chain, res in top_residues:
+        for _, chain, res in top_residues:
             one_indexed_res = res + 1
             xyzview.addResLabels(
                 {"chain": chain, "resi": one_indexed_res},
@@ -265,7 +264,7 @@ Pick a PDB ID, layer and head to visualize attention from the selected protein l
 
 chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
 data = []
-for att_weight, _, chain, resi in top_residues:
+for att_weight, chain, resi in top_residues:
     try:
         res = chain_dict[chain][resi]
     except KeyError:
@@ -273,9 +272,9 @@ for att_weight, _, chain, resi in top_residues:
     el = (att_weight, f"{res.resname:3}{res.id[1]}({chain})")
     data.append(el)
 
-df = pd.DataFrame(data, columns=["Total attention (disregarding direction)", "Residue"])
+df = pd.DataFrame(data, columns=["Total attention to", "Residue"])
 st.markdown(
-    f"The {n_highest_resis} residues (per chain) with the highest attention sums are labeled in the visualization and listed here:"
+    f"The {n_highest_resis} residues (per chain) with the highest attention to them are labeled in the visualization and listed here:"
 )
 st.table(df)
 
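
For completeness, a tiny illustration (hard-coded values, not real output) of the new top_residues shape these changes consume: each entry is now (attention_sum, chain_id, residue_index) with a zero-based index, which is why the viewer code adds 1 before labelling.

top_residues = [(2.31, "A", 41), (1.87, "B", 6)]  # made-up example data
for attn_sum, chain, res in top_residues:
    one_indexed_res = res + 1  # shift to the 1-based numbering used for labels
    print(f"{chain}{one_indexed_res}: {attn_sum:.2f}")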