Spaces:
Sleeping
Sleeping
Label top n attention pairs
Browse files
To have the most interesting attention pairs labeled and identifiable immediately.
- hexviz/app.py +16 -1
- hexviz/attention.py +1 -1
hexviz/app.py
CHANGED
@@ -46,6 +46,8 @@ with right:
|
|
46 |
|
47 |
with st.expander("Attention parameters", expanded=False):
|
48 |
min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
|
|
|
|
|
49 |
|
50 |
# TODO add avg or max attention as params
|
51 |
|
@@ -58,6 +60,9 @@ with st.expander("Attention parameters", expanded=False):
|
|
58 |
|
59 |
attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
|
60 |
|
|
|
|
|
|
|
61 |
def get_3dview(pdb):
|
62 |
xyzview = py3Dmol.view(query=f"pdb:{pdb}")
|
63 |
xyzview.setStyle({"cartoon": {"color": "spectrum"}})
|
@@ -68,13 +73,22 @@ def get_3dview(pdb):
|
|
68 |
for chain in hidden_chains:
|
69 |
xyzview.setStyle({"chain": chain},{"cross":{"hidden":"true"}})
|
70 |
|
71 |
-
for att_weight, first, second in attention_pairs:
|
72 |
stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
|
|
|
|
|
73 |
|
74 |
if label_resi:
|
75 |
for hl_resi in hl_resi_list:
|
76 |
xyzview.addResLabels({"chain": hl_chain,"resi": hl_resi},
|
77 |
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
return xyzview
|
79 |
|
80 |
|
@@ -82,6 +96,7 @@ xyzview = get_3dview(pdb_id)
|
|
82 |
showmol(xyzview, height=500, width=800)
|
83 |
st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
|
84 |
|
|
|
85 |
"""
|
86 |
More models will be added soon. The attention visualization is inspired by [provis](https://github.com/salesforce/provis#provis-attention-visualizer).
|
87 |
"""
|
|
|
46 |
|
47 |
with st.expander("Attention parameters", expanded=False):
|
48 |
min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
|
49 |
+
n_pairs = st.number_input("Num attention pairs labeled", value=2, min_value=1, max_value=100)
|
50 |
+
label_highest = st.checkbox("Label highest attention pairs", value=True)
|
51 |
|
52 |
# TODO add avg or max attention as params
|
53 |
|
|
|
60 |
|
61 |
attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
|
62 |
|
63 |
+
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
|
64 |
+
top_n = sorted_by_attention[:n_pairs]
|
65 |
+
|
66 |
def get_3dview(pdb):
|
67 |
xyzview = py3Dmol.view(query=f"pdb:{pdb}")
|
68 |
xyzview.setStyle({"cartoon": {"color": "spectrum"}})
|
|
|
73 |
for chain in hidden_chains:
|
74 |
xyzview.setStyle({"chain": chain},{"cross":{"hidden":"true"}})
|
75 |
|
76 |
+
for att_weight, first, second, _, _, _ in attention_pairs:
|
77 |
stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
|
78 |
+
|
79 |
+
# get_max_attention(n_pairs)
|
80 |
|
81 |
if label_resi:
|
82 |
for hl_resi in hl_resi_list:
|
83 |
xyzview.addResLabels({"chain": hl_chain,"resi": hl_resi},
|
84 |
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
85 |
+
|
86 |
+
if label_highest:
|
87 |
+
for _, _, _, chain, a, b in top_n:
|
88 |
+
xyzview.addResLabels({"chain": chain,"resi": a},
|
89 |
+
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
90 |
+
xyzview.addResLabels({"chain": chain,"resi": b},
|
91 |
+
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
92 |
return xyzview
|
93 |
|
94 |
|
|
|
96 |
showmol(xyzview, height=500, width=800)
|
97 |
st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
|
98 |
|
99 |
+
|
100 |
"""
|
101 |
More models will be added soon. The attention visualization is inspired by [provis](https://github.com/salesforce/provis#provis-attention-visualizer).
|
102 |
"""
|
hexviz/attention.py
CHANGED
@@ -125,6 +125,6 @@ def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optiona
|
|
125 |
coord_2 = chain[res_2]["CA"].coord.tolist()
|
126 |
except KeyError:
|
127 |
continue
|
128 |
-
attention_pairs.append((attn_value, coord_1, coord_2))
|
129 |
|
130 |
return attention_pairs
|
|
|
125 |
coord_2 = chain[res_2]["CA"].coord.tolist()
|
126 |
except KeyError:
|
127 |
continue
|
128 |
+
attention_pairs.append((attn_value, coord_1, coord_2, chain.id, res_1, res_2))
|
129 |
|
130 |
return attention_pairs
|