aksell commited on
Commit
db42d48
·
1 Parent(s): 53a7dc6

Calculate pct of total attention

Browse files
hexviz/attention.py CHANGED
@@ -299,6 +299,7 @@ def get_attention_pairs(
299
  top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)
300
 
301
  for res, attn_sum in zip(top_n_indexes, top_n_values):
302
- top_residues.append((attn_sum.item(), chain_ids[i], res.item()))
 
303
 
304
  return attention_pairs, top_residues
 
299
  top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)
300
 
301
  for res, attn_sum in zip(top_n_indexes, top_n_values):
302
+ fraction_of_total_attention = attn_sum.item() / len(sequence)
303
+ top_residues.append((fraction_of_total_attention, chain_ids[i], res.item()))
304
 
305
  return attention_pairs, top_residues
hexviz/🧬Attention_Visualization.py CHANGED
@@ -264,15 +264,19 @@ Pick a PDB ID, layer and head to visualize attention from the selected protein l
264
 
265
  chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
266
  data = []
267
- for att_weight, chain, resi in top_residues:
268
  try:
269
  res = chain_dict[chain][resi]
270
  except KeyError:
271
  continue
272
- el = (att_weight, f"{res.resname:3}{res.id[1]}({chain})")
 
273
  data.append(el)
274
 
275
- df = pd.DataFrame(data, columns=["Total attention to", "Residue"])
 
 
 
276
  st.markdown(
277
  f"The {n_highest_resis} residues (per chain) with the highest attention to them are labeled in the visualization and listed here:"
278
  )
 
264
 
265
  chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
266
  data = []
267
+ for fraction_of_total_attention, chain, resi in top_residues:
268
  try:
269
  res = chain_dict[chain][resi]
270
  except KeyError:
271
  continue
272
+ pct_of_total_attention = round(fraction_of_total_attention * 100, 3)
273
+ el = (pct_of_total_attention, f"{res.resname:3}{res.id[1]}({chain})")
274
  data.append(el)
275
 
276
+ df = pd.DataFrame(data, columns=["% of total attention", "Residue"])
277
+ df = df.style.format(
278
+ {"% of total attention": "{:.3f}"} # Set 3 decimal places for "% of total attention"
279
+ )
280
  st.markdown(
281
  f"The {n_highest_resis} residues (per chain) with the highest attention to them are labeled in the visualization and listed here:"
282
  )