Spaces:
Sleeping
Sleeping
Don't add EC number chars to top residues
Browse files- hexviz/attention.py +19 -14
hexviz/attention.py
CHANGED
@@ -251,6 +251,10 @@ def get_attention_pairs(
|
|
251 |
|
252 |
attention_pairs = []
|
253 |
top_residues = []
|
|
|
|
|
|
|
|
|
254 |
for i, chain in enumerate(chains):
|
255 |
ec_number = ec_numbers[i] if ec_numbers else None
|
256 |
sequence = get_sequence(chain)
|
@@ -266,38 +270,39 @@ def get_attention_pairs(
|
|
266 |
for attn_value, res_1, res_2 in attention_unidirectional:
|
267 |
try:
|
268 |
if not ec_number:
|
|
|
269 |
coord_1 = chain[res_1]["CA"].coord.tolist()
|
270 |
coord_2 = chain[res_2]["CA"].coord.tolist()
|
271 |
else:
|
272 |
-
if res_1
|
273 |
coord_1 = ec_number[res_1].coordinate
|
274 |
else:
|
275 |
-
coord_1 = chain[res_1 -
|
276 |
-
if res_2
|
277 |
coord_2 = ec_number[res_2].coordinate
|
278 |
else:
|
279 |
-
coord_2 = chain[res_2 -
|
280 |
|
281 |
except KeyError:
|
282 |
continue
|
283 |
|
284 |
attention_pairs.append((attn_value, coord_1, coord_2))
|
285 |
-
|
286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
|
288 |
top_n_residues = sorted(
|
289 |
residue_attention.items(), key=lambda x: x[1], reverse=True
|
290 |
)[:top_n]
|
291 |
|
292 |
for res, attn_sum in top_n_residues:
|
293 |
-
|
294 |
-
coord = chain[res]["CA"].coord.tolist()
|
295 |
-
else:
|
296 |
-
if res < 4:
|
297 |
-
# Ignore EC tag chars as these can't be labeled
|
298 |
-
continue
|
299 |
-
else:
|
300 |
-
coord = chain[res - 4]["CA"].coord.tolist()
|
301 |
top_residues.append((attn_sum, coord, chain.id, res))
|
302 |
|
303 |
return attention_pairs, top_residues
|
|
|
251 |
|
252 |
attention_pairs = []
|
253 |
top_residues = []
|
254 |
+
|
255 |
+
ec_tag_length = 4
|
256 |
+
is_tag = lambda x: x < ec_tag_length
|
257 |
+
|
258 |
for i, chain in enumerate(chains):
|
259 |
ec_number = ec_numbers[i] if ec_numbers else None
|
260 |
sequence = get_sequence(chain)
|
|
|
270 |
for attn_value, res_1, res_2 in attention_unidirectional:
|
271 |
try:
|
272 |
if not ec_number:
|
273 |
+
# Should you add 1 here? Arent chains 1 indexed and res indexeds 0 indexed
|
274 |
coord_1 = chain[res_1]["CA"].coord.tolist()
|
275 |
coord_2 = chain[res_2]["CA"].coord.tolist()
|
276 |
else:
|
277 |
+
if is_tag(res_1):
|
278 |
coord_1 = ec_number[res_1].coordinate
|
279 |
else:
|
280 |
+
coord_1 = chain[res_1 - ec_tag_length]["CA"].coord.tolist()
|
281 |
+
if is_tag(res_2):
|
282 |
coord_2 = ec_number[res_2].coordinate
|
283 |
else:
|
284 |
+
coord_2 = chain[res_2 - ec_tag_length]["CA"].coord.tolist()
|
285 |
|
286 |
except KeyError:
|
287 |
continue
|
288 |
|
289 |
attention_pairs.append((attn_value, coord_1, coord_2))
|
290 |
+
if not ec_number:
|
291 |
+
residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
|
292 |
+
residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value
|
293 |
+
else:
|
294 |
+
for res in [res_1, res_2]:
|
295 |
+
if not is_tag(res):
|
296 |
+
residue_attention[res - ec_tag_length] = (
|
297 |
+
residue_attention.get(res - ec_tag_length, 0) + attn_value
|
298 |
+
)
|
299 |
|
300 |
top_n_residues = sorted(
|
301 |
residue_attention.items(), key=lambda x: x[1], reverse=True
|
302 |
)[:top_n]
|
303 |
|
304 |
for res, attn_sum in top_n_residues:
|
305 |
+
coord = chain[res]["CA"].coord.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
top_residues.append((attn_sum, coord, chain.id, res))
|
307 |
|
308 |
return attention_pairs, top_residues
|