Spaces:
Sleeping
Sleeping
Clean up view
Browse files- hexviz/app.py +3 -5
hexviz/app.py
CHANGED
@@ -3,27 +3,25 @@ import stmol
|
|
3 |
import streamlit as st
|
4 |
from stmol import showmol
|
5 |
|
6 |
-
from hexviz.attention import get_attention_pairs
|
7 |
from hexviz.models import Model, ModelType
|
8 |
|
9 |
st.title("Attention Visualization on proteins")
|
10 |
|
11 |
"""
|
12 |
-
Visualize attention weights on protein structures for the protein language models
|
13 |
Pick a PDB ID, layer and head to visualize attention.
|
14 |
"""
|
15 |
|
16 |
models = [
|
17 |
-
# Model(name=ModelType.ProtGPT2, layers=36, heads=20),
|
18 |
Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
|
19 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
20 |
-
# Model(name=ModelType.PROT_T5, layers=24, heads=32),
|
21 |
]
|
22 |
|
23 |
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
|
24 |
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
25 |
|
26 |
-
pdb_id = st.text_input("PDB ID", "
|
27 |
|
28 |
left, right = st.columns(2)
|
29 |
with left:
|
|
|
3 |
import streamlit as st
|
4 |
from stmol import showmol
|
5 |
|
6 |
+
from hexviz.attention import get_attention_pairs, get_structure
|
7 |
from hexviz.models import Model, ModelType
|
8 |
|
9 |
st.title("Attention Visualization on proteins")
|
10 |
|
11 |
"""
|
12 |
+
Visualize attention weights on protein structures for the protein language models TAPE-BERT and ZymCTRL.
|
13 |
Pick a PDB ID, layer and head to visualize attention.
|
14 |
"""
|
15 |
|
16 |
models = [
|
|
|
17 |
Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
|
18 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
|
|
19 |
]
|
20 |
|
21 |
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
|
22 |
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
23 |
|
24 |
+
pdb_id = st.text_input("PDB ID", "1I60")
|
25 |
|
26 |
left, right = st.columns(2)
|
27 |
with left:
|