Add ModelType enum and Model class to hold layers and head count
protention/attention.py
CHANGED
@@ -2,14 +2,24 @@ from enum import Enum
 from io import StringIO
 from urllib import request
 
+import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
 from tape import ProteinBertModel, TAPETokenizer
 from transformers import T5EncoderModel, T5Tokenizer
 
 
-class …
-…
+class ModelType(str, Enum):
+    TAPE_BERT = "bert-base"
+    PROT_T5 = "prot_t5_xl_half_uniref50-enc"
+
+
+class Model:
+    def __init__(self, name, layers, heads):
+        self.name: ModelType = name
+        self.layers: int = layers
+        self.heads: int = heads
+
 
 def get_structure(pdb_code: str) -> Structure:
     """
@@ -56,9 +66,9 @@ def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
     model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
     return tokenizer, model
 
-
+@st.cache
 def get_attention(
-    pdb_code: str, model: …
+    pdb_code: str, model: ModelType = ModelType.TAPE_BERT
 ):
     """
     Get attention from T5
@@ -70,8 +80,8 @@ def get_attention(
     # TODO handle multiple sequences
     sequence = sequences[0]
 
-    match model:
-        case …:
+    match model.name:
+        case ModelType.TAPE_BERT:
             tokenizer, model = get_tape_bert()
             token_idxs = tokenizer.encode(sequence).tolist()
             inputs = torch.tensor(token_idxs).unsqueeze(0)
@@ -80,9 +90,10 @@ def get_attention(
             # Remove attention from <CLS> (first) and <SEP> (last) token
             attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
             attns = torch.stack([attn.squeeze(0) for attn in attns])
-        case …:
+        case ModelType.PROT_T5:
             # Space separate sequences
             sequences = [" ".join(sequence) for sequence in sequences]
             tokenizer, model = get_protT5()
 
-    return attns
+    return attns
+
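To make the new plumbing concrete, here is a small self-contained sketch that mirrors the classes added above (re-declared locally so it runs without installing the package; the 12-layer, 12-head figures are the ones the Streamlit app registers below):

```python
from enum import Enum


# Mirrors the definitions added to protention/attention.py so the snippet is standalone.
class ModelType(str, Enum):
    TAPE_BERT = "bert-base"
    PROT_T5 = "prot_t5_xl_half_uniref50-enc"


class Model:
    def __init__(self, name: ModelType, layers: int, heads: int):
        self.name = name
        self.layers = layers
        self.heads = heads


tape_bert = Model(name=ModelType.TAPE_BERT, layers=12, heads=12)

# str-backed enum members compare equal to their checkpoint strings and can be
# recovered from them, which is what lets the UI pass plain names around.
assert ModelType.TAPE_BERT == "bert-base"
assert ModelType(tape_bert.name.value) is ModelType.TAPE_BERT
print(tape_bert.name.value, tape_bert.layers, tape_bert.heads)  # bert-base 12 12
```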
protention/streamlit/Attention_On_Structure.py
CHANGED
@@ -3,21 +3,31 @@ import stmol
 import streamlit as st
 from stmol import showmol
 
+from protention.attention import Model, ModelType, get_attention
+
 st.sidebar.title("pLM Attention Visualization")
 
 st.title("pLM Attention Visualization")
 
+# Define list of model types
+models = [
+    Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
+]
+
+selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
+selected_model = next((model for model in models if model.name.value == selected_model_name), None)
+
 pdb_id = st.text_input("PDB ID", "4RW0")
-chain_id = None
 
 left, right = st.columns(2)
 with left:
-    layer = st.number_input("Layer", value=…
+    layer = st.number_input("Layer", value=1, min_value=1, max_value=selected_model.layers)
 with right:
-    head = st.number_input("Head", value=…
+    head = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
 
 min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.15)
 
+attention = get_attention(pdb_id, model=selected_model.name)
 
 def get_3dview(pdb):
     xyzview = py3Dmol.view(query=f"pdb:{pdb}")
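The `models` list is what feeds both the selectbox and the layer/head bounds, so adding a second entry is enough to expose another checkpoint in the UI. A sketch, assuming `protention.attention` is importable; the 24-layer, 32-head figures for ProtT5 are an assumption about the prot_t5_xl encoder, not values taken from this commit:

```python
# Hypothetical extension of the app's model registry; ProtT5 sizes are assumed, not from the diff.
from protention.attention import Model, ModelType

models = [
    Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
    Model(name=ModelType.PROT_T5, layers=24, heads=32),  # assumption: T5-XL-sized encoder
]

# Same lookup the app performs after st.selectbox returns the chosen checkpoint name.
selected_model_name = ModelType.PROT_T5.value
selected_model = next((m for m in models if m.name.value == selected_model_name), None)
assert selected_model is not None
assert (selected_model.layers, selected_model.heads) == (24, 32)
```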
tests/test_attention.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 from Bio.PDB.Structure import Structure
 from transformers import T5EncoderModel, T5Tokenizer
 
-from protention.attention import (…
+from protention.attention import (ModelType, get_attention, get_protT5,
                                   get_sequences, get_structure)
 
 
@@ -38,7 +38,7 @@ def test_get_protT5():
 
 def test_get_attention_tape():
 
-    result = get_attention("1AKE", model=…
+    result = get_attention("1AKE", model=ModelType.tape_bert)
 
     assert result is not None
     assert result.shape == torch.Size([12,12,456,456])
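A cheap companion test could exercise the new enum/Model plumbing without downloading weights or touching the network; a sketch, with the test name and structure being illustrative only:

```python
from protention.attention import Model, ModelType


def test_model_metadata_roundtrip():
    # Sketch: checks the new classes directly, no model download required.
    model = Model(name=ModelType.TAPE_BERT, layers=12, heads=12)
    assert model.name == "bert-base"  # str-backed Enum compares to its value
    assert ModelType(model.name.value) is ModelType.TAPE_BERT
    assert (model.layers, model.heads) == (12, 12)
```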