aksell commited on
Commit
4355c21
·
1 Parent(s): beb84f6

Label top n attention pairs

Browse files

To have the most interesting attention pairs labeled and
identifyable immediately.

Files changed (2) hide show
  1. hexviz/app.py +16 -1
  2. hexviz/attention.py +1 -1
hexviz/app.py CHANGED
@@ -46,6 +46,8 @@ with right:
46
 
47
  with st.expander("Attention parameters", expanded=False):
48
  min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
 
 
49
 
50
  # TODO add avg or max attention as params
51
 
@@ -58,6 +60,9 @@ with st.expander("Attention parameters", expanded=False):
58
 
59
  attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
60
 
 
 
 
61
  def get_3dview(pdb):
62
  xyzview = py3Dmol.view(query=f"pdb:{pdb}")
63
  xyzview.setStyle({"cartoon": {"color": "spectrum"}})
@@ -68,13 +73,22 @@ def get_3dview(pdb):
68
  for chain in hidden_chains:
69
  xyzview.setStyle({"chain": chain},{"cross":{"hidden":"true"}})
70
 
71
- for att_weight, first, second in attention_pairs:
72
  stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
 
 
73
 
74
  if label_resi:
75
  for hl_resi in hl_resi_list:
76
  xyzview.addResLabels({"chain": hl_chain,"resi": hl_resi},
77
  {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
 
 
 
 
 
 
 
78
  return xyzview
79
 
80
 
@@ -82,6 +96,7 @@ xyzview = get_3dview(pdb_id)
82
  showmol(xyzview, height=500, width=800)
83
  st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
84
 
 
85
  """
86
  More models will be added soon. The attention visualization is inspired by [provis](https://github.com/salesforce/provis#provis-attention-visualizer).
87
  """
 
46
 
47
  with st.expander("Attention parameters", expanded=False):
48
  min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
49
+ n_pairs = st.number_input("Num attention pairs labeled", value=2, min_value=1, max_value=100)
50
+ label_highest = st.checkbox("Label highest attention pairs", value=True)
51
 
52
  # TODO add avg or max attention as params
53
 
 
60
 
61
  attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
62
 
63
+ sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
64
+ top_n = sorted_by_attention[:n_pairs]
65
+
66
  def get_3dview(pdb):
67
  xyzview = py3Dmol.view(query=f"pdb:{pdb}")
68
  xyzview.setStyle({"cartoon": {"color": "spectrum"}})
 
73
  for chain in hidden_chains:
74
  xyzview.setStyle({"chain": chain},{"cross":{"hidden":"true"}})
75
 
76
+ for att_weight, first, second, _, _, _ in attention_pairs:
77
  stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
78
+
79
+ # get_max_attention(n_pairs)
80
 
81
  if label_resi:
82
  for hl_resi in hl_resi_list:
83
  xyzview.addResLabels({"chain": hl_chain,"resi": hl_resi},
84
  {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
85
+
86
+ if label_highest:
87
+ for _, _, _, chain, a, b in top_n:
88
+ xyzview.addResLabels({"chain": chain,"resi": a},
89
+ {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
90
+ xyzview.addResLabels({"chain": chain,"resi": b},
91
+ {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
92
  return xyzview
93
 
94
 
 
96
  showmol(xyzview, height=500, width=800)
97
  st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
98
 
99
+
100
  """
101
  More models will be added soon. The attention visualization is inspired by [provis](https://github.com/salesforce/provis#provis-attention-visualizer).
102
  """
hexviz/attention.py CHANGED
@@ -125,6 +125,6 @@ def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optiona
125
  coord_2 = chain[res_2]["CA"].coord.tolist()
126
  except KeyError:
127
  continue
128
- attention_pairs.append((attn_value, coord_1, coord_2))
129
 
130
  return attention_pairs
 
125
  coord_2 = chain[res_2]["CA"].coord.tolist()
126
  except KeyError:
127
  continue
128
+ attention_pairs.append((attn_value, coord_1, coord_2, chain.id, res_1, res_2))
129
 
130
  return attention_pairs