Spaces:
Sleeping
Sleeping
WIP visualize EC number and attention to it
Browse files
Visualizes EC tags as spheres on the protein structure.
- hexviz/attention.py +64 -4
- hexviz/ec_number.py +11 -0
- hexviz/🧬Attention_Visualization.py +65 -7
hexviz/attention.py
CHANGED
@@ -6,6 +6,7 @@ import streamlit as st
|
|
6 |
import torch
|
7 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
8 |
|
|
|
9 |
from hexviz.models import (
|
10 |
ModelType,
|
11 |
get_prot_bert,
|
@@ -98,11 +99,39 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
|
|
98 |
return cleaned_sequence, None
|
99 |
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
@st.cache
|
102 |
def get_attention(
|
103 |
sequence: str,
|
104 |
model_type: ModelType = ModelType.TAPE_BERT,
|
105 |
remove_special_tokens: bool = True,
|
|
|
106 |
):
|
107 |
"""
|
108 |
Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
|
@@ -122,6 +151,10 @@ def get_attention(
|
|
122 |
|
123 |
elif model_type == ModelType.ZymCTRL:
|
124 |
tokenizer, model = get_zymctrl()
|
|
|
|
|
|
|
|
|
125 |
inputs = tokenizer(sequence, return_tensors="pt").input_ids.to(device)
|
126 |
attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(
|
127 |
device
|
@@ -133,6 +166,12 @@ def get_attention(
|
|
133 |
)
|
134 |
attentions = outputs.attentions
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
# torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
|
137 |
attention_squeezed = [torch.squeeze(attention) for attention in attentions]
|
138 |
# ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
|
@@ -202,6 +241,7 @@ def get_attention_pairs(
|
|
202 |
threshold: int = 0.2,
|
203 |
model_type: ModelType = ModelType.TAPE_BERT,
|
204 |
top_n: int = 2,
|
|
|
205 |
):
|
206 |
structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
|
207 |
if chain_ids:
|
@@ -213,7 +253,9 @@ def get_attention_pairs(
|
|
213 |
top_residues = []
|
214 |
for chain in chains:
|
215 |
sequence = get_sequence(chain)
|
216 |
-
attention = get_attention(
|
|
|
|
|
217 |
attention_unidirectional = unidirectional_avg_filtered(
|
218 |
attention, layer, head, threshold
|
219 |
)
|
@@ -222,8 +264,19 @@ def get_attention_pairs(
|
|
222 |
residue_attention = {}
|
223 |
for attn_value, res_1, res_2 in attention_unidirectional:
|
224 |
try:
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
except KeyError:
|
228 |
continue
|
229 |
|
@@ -236,7 +289,14 @@ def get_attention_pairs(
|
|
236 |
)[:top_n]
|
237 |
|
238 |
for res, attn_sum in top_n_residues:
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
top_residues.append((attn_sum, coord, chain.id, res))
|
241 |
|
242 |
return attention_pairs, top_residues
|
|
|
6 |
import torch
|
7 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
8 |
|
9 |
+
from hexviz.ec_number import ECNumber
|
10 |
from hexviz.models import (
|
11 |
ModelType,
|
12 |
get_prot_bert,
|
|
|
99 |
return cleaned_sequence, None
|
100 |
|
101 |
|
102 |
+
def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
    """
    Drop the attention rows and columns of special tokens and periods.

    The sequence is tokenized with ``tokenizer.tokenize`` and every position
    whose token is ".", "<sep>", "<start>", "<end>" or "<pad>" is removed from
    dimensions 2 and 3 of each attention tensor.

    Args:
        attentions_tuple: iterable of 4-D attention tensors, one per layer
            (the caller describes them as [1, n_heads, n_tokens, n_tokens]).
        sequence: the exact string that was fed to the model.
        tokenizer: object exposing ``tokenize(sequence) -> list[str]``.

    Returns:
        A list with one filtered tensor per input tensor, in the same order.

    NOTE(review): assumes token i of ``tokenizer.tokenize(sequence)`` lines up
    with position i of the attention matrices -- confirm for the tokenizer used.
    """
    unwanted_tokens = {".", "<sep>", "<start>", "<end>", "<pad>"}
    positions_to_drop = {
        position
        for position, token in enumerate(tokenizer.tokenize(sequence))
        if token in unwanted_tokens
    }

    # Nothing to filter: hand the tensors back untouched, as a list.
    if not positions_to_drop:
        return list(attentions_tuple)

    new_attentions = []
    for attentions in attentions_tuple:
        # Gather the surviving positions once per axis instead of rebuilding
        # the whole tensor with torch.cat for every removed position.
        for dim in (2, 3):
            kept = [
                position
                for position in range(attentions.shape[dim])
                if position not in positions_to_drop
            ]
            kept_index = torch.tensor(
                kept, dtype=torch.long, device=attentions.device
            )
            attentions = attentions.index_select(dim, kept_index)
        new_attentions.append(attentions)

    return new_attentions
|
127 |
+
|
128 |
+
|
129 |
@st.cache
|
130 |
def get_attention(
|
131 |
sequence: str,
|
132 |
model_type: ModelType = ModelType.TAPE_BERT,
|
133 |
remove_special_tokens: bool = True,
|
134 |
+
ec_number: list[ECNumber] = None,
|
135 |
):
|
136 |
"""
|
137 |
Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
|
|
|
151 |
|
152 |
elif model_type == ModelType.ZymCTRL:
|
153 |
tokenizer, model = get_zymctrl()
|
154 |
+
|
155 |
+
if ec_number:
|
156 |
+
sequence = f"{'.'.join([ec.number for ec in ec_number])}<sep><start>{sequence}<end><pad>"
|
157 |
+
|
158 |
inputs = tokenizer(sequence, return_tensors="pt").input_ids.to(device)
|
159 |
attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(
|
160 |
device
|
|
|
166 |
)
|
167 |
attentions = outputs.attentions
|
168 |
|
169 |
+
if ec_number:
|
170 |
+
# Remove attention to special tokens and periods separating EC number components
|
171 |
+
attentions = remove_special_tokens_and_periods(
|
172 |
+
attentions, sequence, tokenizer
|
173 |
+
)
|
174 |
+
|
175 |
# torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
|
176 |
attention_squeezed = [torch.squeeze(attention) for attention in attentions]
|
177 |
# ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
|
|
|
241 |
threshold: int = 0.2,
|
242 |
model_type: ModelType = ModelType.TAPE_BERT,
|
243 |
top_n: int = 2,
|
244 |
+
ec_number: list[ECNumber] | None = None,
|
245 |
):
|
246 |
structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
|
247 |
if chain_ids:
|
|
|
253 |
top_residues = []
|
254 |
for chain in chains:
|
255 |
sequence = get_sequence(chain)
|
256 |
+
attention = get_attention(
|
257 |
+
sequence=sequence, model_type=model_type, ec_number=ec_number
|
258 |
+
)
|
259 |
attention_unidirectional = unidirectional_avg_filtered(
|
260 |
attention, layer, head, threshold
|
261 |
)
|
|
|
264 |
residue_attention = {}
|
265 |
for attn_value, res_1, res_2 in attention_unidirectional:
|
266 |
try:
|
267 |
+
if not ec_number:
|
268 |
+
coord_1 = chain[res_1]["CA"].coord.tolist()
|
269 |
+
coord_2 = chain[res_2]["CA"].coord.tolist()
|
270 |
+
else:
|
271 |
+
if res_1 < 4:
|
272 |
+
coord_1 = ec_number[res_1].coordinate
|
273 |
+
else:
|
274 |
+
coord_1 = chain[res_1 - 4]["CA"].coord.tolist()
|
275 |
+
if res_2 < 4:
|
276 |
+
coord_2 = ec_number[res_2].coordinate
|
277 |
+
else:
|
278 |
+
coord_2 = chain[res_2 - 4]["CA"].coord.tolist()
|
279 |
+
|
280 |
except KeyError:
|
281 |
continue
|
282 |
|
|
|
289 |
)[:top_n]
|
290 |
|
291 |
for res, attn_sum in top_n_residues:
|
292 |
+
if not ec_number:
|
293 |
+
coord = chain[res]["CA"].coord.tolist()
|
294 |
+
else:
|
295 |
+
if res < 4:
|
296 |
+
# Ignore EC tag chars as these can't be labeled
|
297 |
+
continue
|
298 |
+
else:
|
299 |
+
coord = chain[res - 4]["CA"].coord.tolist()
|
300 |
top_residues.append((attn_sum, coord, chain.id, res))
|
301 |
|
302 |
return attention_pairs, top_residues
|
hexviz/ec_number.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ECNumber:
    """
    One component of an Enzyme Commission (EC) number and how to draw it.

    Attributes:
        number: the component's text (callers join these with "." to rebuild
            the full EC number).
        coordinate: position at which the component is placed in the 3D view.
        color: color name used for the component.
        radius: radius of the sphere drawn for the component.
    """

    def __init__(
        self,
        number: str,
        coordinate: list[float],
        color: str,
        radius: float,
    ):
        self.number = number
        self.coordinate = coordinate
        self.color = color
        self.radius = radius

    def __str__(self) -> str:
        """Return a short human-readable summary (radius is not included)."""
        return (
            f"(EC: {self.number}, Coordinate: {self.coordinate}, Color: {self.color})"
        )

    def __repr__(self) -> str:
        """Return an unambiguous representation, so lists of tags print readably."""
        return (
            f"{type(self).__name__}(number={self.number!r}, "
            f"coordinate={self.coordinate!r}, color={self.color!r}, "
            f"radius={self.radius!r})"
        )
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
import py3Dmol
|
3 |
import stmol
|
@@ -9,10 +12,10 @@ from hexviz.attention import (
|
|
9 |
get_attention_pairs,
|
10 |
get_chains,
|
11 |
)
|
|
|
|
|
12 |
from hexviz.models import Model, ModelType
|
13 |
from hexviz.view import menu_items, select_model, select_pdb, select_protein
|
14 |
-
from hexviz.config import URL
|
15 |
-
|
16 |
|
17 |
st.set_page_config(layout="centered", menu_items=menu_items)
|
18 |
st.title("Attention Visualization on proteins")
|
@@ -110,15 +113,60 @@ with right:
|
|
110 |
)
|
111 |
head = head_one - 1
|
112 |
|
113 |
-
|
114 |
if selected_model.name == ModelType.ZymCTRL:
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
try:
|
116 |
-
|
117 |
except KeyError:
|
118 |
pass
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
|
124 |
attention_pairs, top_residues = get_attention_pairs(
|
@@ -129,6 +177,7 @@ attention_pairs, top_residues = get_attention_pairs(
|
|
129 |
threshold=min_attn,
|
130 |
model_type=selected_model.name,
|
131 |
top_n=n_highest_resis,
|
|
|
132 |
)
|
133 |
|
134 |
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
|
@@ -169,6 +218,15 @@ def get_3dview(pdb):
|
|
169 |
dashed=False,
|
170 |
)
|
171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
if label_resi:
|
173 |
for hl_resi in hl_resi_list:
|
174 |
xyzview.addResLabels(
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
import pandas as pd
|
5 |
import py3Dmol
|
6 |
import stmol
|
|
|
12 |
get_attention_pairs,
|
13 |
get_chains,
|
14 |
)
|
15 |
+
from hexviz.config import URL
|
16 |
+
from hexviz.ec_number import ECNumber
|
17 |
from hexviz.models import Model, ModelType
|
18 |
from hexviz.view import menu_items, select_model, select_pdb, select_protein
|
|
|
|
|
19 |
|
20 |
st.set_page_config(layout="centered", menu_items=menu_items)
|
21 |
st.title("Attention Visualization on proteins")
|
|
|
113 |
)
|
114 |
head = head_one - 1
|
115 |
|
116 |
+
ec_number = ""
|
117 |
if selected_model.name == ModelType.ZymCTRL:
|
118 |
+
st.sidebar.markdown(
|
119 |
+
"""
|
120 |
+
ZymCTRL EC number
|
121 |
+
---
|
122 |
+
"""
|
123 |
+
)
|
124 |
try:
|
125 |
+
ec_number = structure.header["compound"]["1"]["ec"]
|
126 |
except KeyError:
|
127 |
pass
|
128 |
+
ec_number = st.sidebar.text_input("Enzyme Comission number (EC)", ec_number)
|
129 |
+
|
130 |
+
# Validate EC number
|
131 |
+
if not re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", ec_number):
|
132 |
+
st.sidebar.error(
|
133 |
+
"Please enter a valid Enzyme Commission number in the format of 4 integers separated by periods (e.g., 1.2.3.21)"
|
134 |
+
)
|
135 |
+
|
136 |
+
if ec_number:
|
137 |
+
if selected_chains:
|
138 |
+
all_chains = [
|
139 |
+
ch for ch in structure.get_chains() if ch.id in selected_chains
|
140 |
+
]
|
141 |
+
else:
|
142 |
+
all_chains = list(structure.get_chains())
|
143 |
+
the_chain = all_chains[0]
|
144 |
+
res_1 = the_chain[1]["CA"].coord.tolist()
|
145 |
+
res_2 = the_chain[2]["CA"].coord.tolist()
|
146 |
+
|
147 |
+
# Calculate the vector from res_1 to res_2
|
148 |
+
vector = [res_2[i] - res_1[i] for i in range(3)]
|
149 |
+
|
150 |
+
# Reverse the vector
|
151 |
+
reverse_vector = [-v for v in vector]
|
152 |
+
|
153 |
+
# Normalize the reverse vector
|
154 |
+
reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
|
155 |
+
reverse_vector
|
156 |
+
)
|
157 |
+
radius = 1
|
158 |
+
coordinates = [
|
159 |
+
[res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
|
160 |
+
for i in range(4)
|
161 |
+
]
|
162 |
+
colors = ["blue", "green", "orange", "red"]
|
163 |
+
EC_numbers = ec_number.split(".")
|
164 |
+
EC_tag = [
|
165 |
+
ECNumber(number=num, coordinate=coord, color=color, radius=radius)
|
166 |
+
for num, coord, color in zip(EC_numbers, coordinates, colors)
|
167 |
+
]
|
168 |
+
EC_colored = [f":{color}[{EC.number}]" for EC, color in zip(EC_tag, colors)]
|
169 |
+
st.sidebar.write("Visualized as colored spheres: " + ".".join(EC_colored))
|
170 |
|
171 |
|
172 |
attention_pairs, top_residues = get_attention_pairs(
|
|
|
177 |
threshold=min_attn,
|
178 |
model_type=selected_model.name,
|
179 |
top_n=n_highest_resis,
|
180 |
+
ec_number=EC_tag if ec_number else None,
|
181 |
)
|
182 |
|
183 |
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
|
|
|
218 |
dashed=False,
|
219 |
)
|
220 |
|
221 |
+
if selected_model.name == ModelType.ZymCTRL and ec_number:
|
222 |
+
for EC_num in EC_tag:
|
223 |
+
stmol.add_sphere(
|
224 |
+
xyzview,
|
225 |
+
spcenter=EC_num.coordinate,
|
226 |
+
radius=EC_num.radius,
|
227 |
+
spColor=EC_num.color,
|
228 |
+
)
|
229 |
+
|
230 |
if label_resi:
|
231 |
for hl_resi in hl_resi_list:
|
232 |
xyzview.addResLabels(
|