aksell commited on
Commit
53a7dc6
·
1 Parent(s): 725c921

Fix bug in column vise attention calc

Browse files
hexviz/attention.py CHANGED
@@ -293,9 +293,9 @@ def get_attention_pairs(
293
  residue_attention.get(res - ec_tag_length, 0) + attn_value
294
  )
295
  if not ec_number:
296
- attention_into_res = attention[head, layer].sum(dim=0)
297
  else:
298
- attention_into_res = attention[head, layer, ec_tag_length:, ec_tag_length:].sum(dim=0)
299
  top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)
300
 
301
  for res, attn_sum in zip(top_n_indexes, top_n_values):
 
293
  residue_attention.get(res - ec_tag_length, 0) + attn_value
294
  )
295
  if not ec_number:
296
+ attention_into_res = attention[layer, head].sum(dim=0)
297
  else:
298
+ attention_into_res = attention[layer, head, ec_tag_length:, ec_tag_length:].sum(dim=0)
299
  top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)
300
 
301
  for res, attn_sum in zip(top_n_indexes, top_n_values):
hexviz/🧬Attention_Visualization.py CHANGED
@@ -28,7 +28,7 @@ models = [
28
  ]
29
 
30
  with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
31
- pdb_id = select_pdb() or "2WK4"
32
  uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
33
  input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
34
  sequence, error = clean_and_validate_sequence(input_sequence)
 
28
  ]
29
 
30
  with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
31
+ pdb_id = select_pdb()
32
  uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
33
  input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
34
  sequence, error = clean_and_validate_sequence(input_sequence)