Spaces:
Sleeping
Sleeping
Visualize EC tags for multiple chains
Browse files- hexviz/attention.py +3 -2
- hexviz/🧬Attention_Visualization.py +48 -33
hexviz/attention.py
CHANGED
@@ -241,7 +241,7 @@ def get_attention_pairs(
|
|
241 |
threshold: int = 0.2,
|
242 |
model_type: ModelType = ModelType.TAPE_BERT,
|
243 |
top_n: int = 2,
|
244 |
-
|
245 |
):
|
246 |
structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
|
247 |
if chain_ids:
|
@@ -251,7 +251,8 @@ def get_attention_pairs(
|
|
251 |
|
252 |
attention_pairs = []
|
253 |
top_residues = []
|
254 |
-
for chain in chains:
|
|
|
255 |
sequence = get_sequence(chain)
|
256 |
attention = get_attention(
|
257 |
sequence=sequence, model_type=model_type, ec_number=ec_number
|
|
|
241 |
threshold: int = 0.2,
|
242 |
model_type: ModelType = ModelType.TAPE_BERT,
|
243 |
top_n: int = 2,
|
244 |
+
ec_numbers: list[list[ECNumber]] | None = None,
|
245 |
):
|
246 |
structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
|
247 |
if chain_ids:
|
|
|
251 |
|
252 |
attention_pairs = []
|
253 |
top_residues = []
|
254 |
+
for i, chain in enumerate(chains):
|
255 |
+
ec_number = ec_numbers[i] if ec_numbers else None
|
256 |
sequence = get_sequence(chain)
|
257 |
attention = get_attention(
|
258 |
sequence=sequence, model_type=model_type, ec_number=ec_number
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -135,37 +135,51 @@ if selected_model.name == ModelType.ZymCTRL:
|
|
135 |
|
136 |
if ec_number:
|
137 |
if selected_chains:
|
138 |
-
|
139 |
ch for ch in structure.get_chains() if ch.id in selected_chains
|
140 |
]
|
141 |
else:
|
142 |
-
|
143 |
-
the_chain = all_chains[0]
|
144 |
-
res_1 = the_chain[1]["CA"].coord.tolist()
|
145 |
-
res_2 = the_chain[2]["CA"].coord.tolist()
|
146 |
|
147 |
-
|
148 |
-
vector = [res_2[i] - res_1[i] for i in range(3)]
|
149 |
-
|
150 |
-
# Reverse the vector
|
151 |
-
reverse_vector = [-v for v in vector]
|
152 |
-
|
153 |
-
# Normalize the reverse vector
|
154 |
-
reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
|
155 |
-
reverse_vector
|
156 |
-
)
|
157 |
-
radius = 1
|
158 |
-
coordinates = [
|
159 |
-
[res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
|
160 |
-
for i in range(4)
|
161 |
-
]
|
162 |
colors = ["blue", "green", "orange", "red"]
|
|
|
163 |
EC_numbers = ec_number.split(".")
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
st.sidebar.write("Visualized as colored spheres: " + ".".join(EC_colored))
|
170 |
|
171 |
|
@@ -177,7 +191,7 @@ attention_pairs, top_residues = get_attention_pairs(
|
|
177 |
threshold=min_attn,
|
178 |
model_type=selected_model.name,
|
179 |
top_n=n_highest_resis,
|
180 |
-
|
181 |
)
|
182 |
|
183 |
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
|
@@ -219,13 +233,14 @@ def get_3dview(pdb):
|
|
219 |
)
|
220 |
|
221 |
if selected_model.name == ModelType.ZymCTRL and ec_number:
|
222 |
-
for
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
229 |
|
230 |
if label_resi:
|
231 |
for hl_resi in hl_resi_list:
|
|
|
135 |
|
136 |
if ec_number:
|
137 |
if selected_chains:
|
138 |
+
shown_chains = [
|
139 |
ch for ch in structure.get_chains() if ch.id in selected_chains
|
140 |
]
|
141 |
else:
|
142 |
+
shown_chains = list(structure.get_chains())
|
|
|
|
|
|
|
143 |
|
144 |
+
EC_tags = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
colors = ["blue", "green", "orange", "red"]
|
146 |
+
radius = 1
|
147 |
EC_numbers = ec_number.split(".")
|
148 |
+
for ch in shown_chains:
|
149 |
+
first_residues = []
|
150 |
+
i = 1
|
151 |
+
while len(first_residues) < 2:
|
152 |
+
try:
|
153 |
+
first_residues.append(ch[i]["CA"].coord.tolist())
|
154 |
+
except KeyError:
|
155 |
+
pass
|
156 |
+
i += 1
|
157 |
+
res_1, res_2 = first_residues
|
158 |
+
|
159 |
+
# Calculate the vector from res_1 to res_2
|
160 |
+
vector = [res_2[i] - res_1[i] for i in range(3)]
|
161 |
+
|
162 |
+
# Reverse the vector
|
163 |
+
reverse_vector = [-v for v in vector]
|
164 |
+
|
165 |
+
# Normalize the reverse vector
|
166 |
+
reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
|
167 |
+
reverse_vector
|
168 |
+
)
|
169 |
+
coordinates = [
|
170 |
+
[
|
171 |
+
res_1[j] + i * 2 * radius * reverse_vector_normalized[j]
|
172 |
+
for j in range(3)
|
173 |
+
]
|
174 |
+
for i in range(4)
|
175 |
+
]
|
176 |
+
EC_tag = [
|
177 |
+
ECNumber(number=num, coordinate=coord, color=color, radius=radius)
|
178 |
+
for num, coord, color in zip(EC_numbers, coordinates, colors)
|
179 |
+
]
|
180 |
+
EC_tags.append(EC_tag)
|
181 |
+
|
182 |
+
EC_colored = [f":{color}[{num}]" for num, color in zip(EC_numbers, colors)]
|
183 |
st.sidebar.write("Visualized as colored spheres: " + ".".join(EC_colored))
|
184 |
|
185 |
|
|
|
191 |
threshold=min_attn,
|
192 |
model_type=selected_model.name,
|
193 |
top_n=n_highest_resis,
|
194 |
+
ec_numbers=EC_tags if ec_number else None,
|
195 |
)
|
196 |
|
197 |
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
|
|
|
233 |
)
|
234 |
|
235 |
if selected_model.name == ModelType.ZymCTRL and ec_number:
|
236 |
+
for EC_tag in EC_tags:
|
237 |
+
for EC_num in EC_tag:
|
238 |
+
stmol.add_sphere(
|
239 |
+
xyzview,
|
240 |
+
spcenter=EC_num.coordinate,
|
241 |
+
radius=EC_num.radius,
|
242 |
+
spColor=EC_num.color,
|
243 |
+
)
|
244 |
|
245 |
if label_resi:
|
246 |
for hl_resi in hl_resi_list:
|