aksell commited on
Commit
528cd6e
·
1 Parent(s): 14646a1

Don't add EC number chars to top residues

Browse files
Files changed (1) hide show
  1. hexviz/attention.py +19 -14
hexviz/attention.py CHANGED
@@ -251,6 +251,10 @@ def get_attention_pairs(
251
 
252
  attention_pairs = []
253
  top_residues = []
 
 
 
 
254
  for i, chain in enumerate(chains):
255
  ec_number = ec_numbers[i] if ec_numbers else None
256
  sequence = get_sequence(chain)
@@ -266,38 +270,39 @@ def get_attention_pairs(
266
  for attn_value, res_1, res_2 in attention_unidirectional:
267
  try:
268
  if not ec_number:
 
269
  coord_1 = chain[res_1]["CA"].coord.tolist()
270
  coord_2 = chain[res_2]["CA"].coord.tolist()
271
  else:
272
- if res_1 < 4:
273
  coord_1 = ec_number[res_1].coordinate
274
  else:
275
- coord_1 = chain[res_1 - 4]["CA"].coord.tolist()
276
- if res_2 < 4:
277
  coord_2 = ec_number[res_2].coordinate
278
  else:
279
- coord_2 = chain[res_2 - 4]["CA"].coord.tolist()
280
 
281
  except KeyError:
282
  continue
283
 
284
  attention_pairs.append((attn_value, coord_1, coord_2))
285
- residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
286
- residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value
 
 
 
 
 
 
 
287
 
288
  top_n_residues = sorted(
289
  residue_attention.items(), key=lambda x: x[1], reverse=True
290
  )[:top_n]
291
 
292
  for res, attn_sum in top_n_residues:
293
- if not ec_number:
294
- coord = chain[res]["CA"].coord.tolist()
295
- else:
296
- if res < 4:
297
- # Ignore EC tag chars as these can't be labeled
298
- continue
299
- else:
300
- coord = chain[res - 4]["CA"].coord.tolist()
301
  top_residues.append((attn_sum, coord, chain.id, res))
302
 
303
  return attention_pairs, top_residues
 
251
 
252
  attention_pairs = []
253
  top_residues = []
254
+
255
+ ec_tag_length = 4
256
+ is_tag = lambda x: x < ec_tag_length
257
+
258
  for i, chain in enumerate(chains):
259
  ec_number = ec_numbers[i] if ec_numbers else None
260
  sequence = get_sequence(chain)
 
270
  for attn_value, res_1, res_2 in attention_unidirectional:
271
  try:
272
  if not ec_number:
273
+ # Should you add 1 here? Arent chains 1 indexed and res indexeds 0 indexed
274
  coord_1 = chain[res_1]["CA"].coord.tolist()
275
  coord_2 = chain[res_2]["CA"].coord.tolist()
276
  else:
277
+ if is_tag(res_1):
278
  coord_1 = ec_number[res_1].coordinate
279
  else:
280
+ coord_1 = chain[res_1 - ec_tag_length]["CA"].coord.tolist()
281
+ if is_tag(res_2):
282
  coord_2 = ec_number[res_2].coordinate
283
  else:
284
+ coord_2 = chain[res_2 - ec_tag_length]["CA"].coord.tolist()
285
 
286
  except KeyError:
287
  continue
288
 
289
  attention_pairs.append((attn_value, coord_1, coord_2))
290
+ if not ec_number:
291
+ residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
292
+ residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value
293
+ else:
294
+ for res in [res_1, res_2]:
295
+ if not is_tag(res):
296
+ residue_attention[res - ec_tag_length] = (
297
+ residue_attention.get(res - ec_tag_length, 0) + attn_value
298
+ )
299
 
300
  top_n_residues = sorted(
301
  residue_attention.items(), key=lambda x: x[1], reverse=True
302
  )[:top_n]
303
 
304
  for res, attn_sum in top_n_residues:
305
+ coord = chain[res]["CA"].coord.tolist()
 
 
 
 
 
 
 
306
  top_residues.append((attn_sum, coord, chain.id, res))
307
 
308
  return attention_pairs, top_residues