aksell commited on
Commit
b6e0c9a
·
1 Parent(s): 9f086ee

Visualize EC tags for multiple chains

Browse files
hexviz/attention.py CHANGED
@@ -241,7 +241,7 @@ def get_attention_pairs(
241
  threshold: int = 0.2,
242
  model_type: ModelType = ModelType.TAPE_BERT,
243
  top_n: int = 2,
244
- ec_number: list[ECNumber] | None = None,
245
  ):
246
  structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
247
  if chain_ids:
@@ -251,7 +251,8 @@ def get_attention_pairs(
251
 
252
  attention_pairs = []
253
  top_residues = []
254
- for chain in chains:
 
255
  sequence = get_sequence(chain)
256
  attention = get_attention(
257
  sequence=sequence, model_type=model_type, ec_number=ec_number
 
241
  threshold: int = 0.2,
242
  model_type: ModelType = ModelType.TAPE_BERT,
243
  top_n: int = 2,
244
+ ec_numbers: list[list[ECNumber]] | None = None,
245
  ):
246
  structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
247
  if chain_ids:
 
251
 
252
  attention_pairs = []
253
  top_residues = []
254
+ for i, chain in enumerate(chains):
255
+ ec_number = ec_numbers[i] if ec_numbers else None
256
  sequence = get_sequence(chain)
257
  attention = get_attention(
258
  sequence=sequence, model_type=model_type, ec_number=ec_number
hexviz/🧬Attention_Visualization.py CHANGED
@@ -135,37 +135,51 @@ if selected_model.name == ModelType.ZymCTRL:
135
 
136
  if ec_number:
137
  if selected_chains:
138
- all_chains = [
139
  ch for ch in structure.get_chains() if ch.id in selected_chains
140
  ]
141
  else:
142
- all_chains = list(structure.get_chains())
143
- the_chain = all_chains[0]
144
- res_1 = the_chain[1]["CA"].coord.tolist()
145
- res_2 = the_chain[2]["CA"].coord.tolist()
146
 
147
- # Calculate the vector from res_1 to res_2
148
- vector = [res_2[i] - res_1[i] for i in range(3)]
149
-
150
- # Reverse the vector
151
- reverse_vector = [-v for v in vector]
152
-
153
- # Normalize the reverse vector
154
- reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
155
- reverse_vector
156
- )
157
- radius = 1
158
- coordinates = [
159
- [res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
160
- for i in range(4)
161
- ]
162
  colors = ["blue", "green", "orange", "red"]
 
163
  EC_numbers = ec_number.split(".")
164
- EC_tag = [
165
- ECNumber(number=num, coordinate=coord, color=color, radius=radius)
166
- for num, coord, color in zip(EC_numbers, coordinates, colors)
167
- ]
168
- EC_colored = [f":{color}[{EC.number}]" for EC, color in zip(EC_tag, colors)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  st.sidebar.write("Visualized as colored spheres: " + ".".join(EC_colored))
170
 
171
 
@@ -177,7 +191,7 @@ attention_pairs, top_residues = get_attention_pairs(
177
  threshold=min_attn,
178
  model_type=selected_model.name,
179
  top_n=n_highest_resis,
180
- ec_number=EC_tag if ec_number else None,
181
  )
182
 
183
  sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
@@ -219,13 +233,14 @@ def get_3dview(pdb):
219
  )
220
 
221
  if selected_model.name == ModelType.ZymCTRL and ec_number:
222
- for EC_num in EC_tag:
223
- stmol.add_sphere(
224
- xyzview,
225
- spcenter=EC_num.coordinate,
226
- radius=EC_num.radius,
227
- spColor=EC_num.color,
228
- )
 
229
 
230
  if label_resi:
231
  for hl_resi in hl_resi_list:
 
135
 
136
  if ec_number:
137
  if selected_chains:
138
+ shown_chains = [
139
  ch for ch in structure.get_chains() if ch.id in selected_chains
140
  ]
141
  else:
142
+ shown_chains = list(structure.get_chains())
 
 
 
143
 
144
+ EC_tags = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  colors = ["blue", "green", "orange", "red"]
146
+ radius = 1
147
  EC_numbers = ec_number.split(".")
148
+ for ch in shown_chains:
149
+ first_residues = []
150
+ i = 1
151
+ while len(first_residues) < 2:
152
+ try:
153
+ first_residues.append(ch[i]["CA"].coord.tolist())
154
+ except KeyError:
155
+ pass
156
+ i += 1
157
+ res_1, res_2 = first_residues
158
+
159
+ # Calculate the vector from res_1 to res_2
160
+ vector = [res_2[i] - res_1[i] for i in range(3)]
161
+
162
+ # Reverse the vector
163
+ reverse_vector = [-v for v in vector]
164
+
165
+ # Normalize the reverse vector
166
+ reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
167
+ reverse_vector
168
+ )
169
+ coordinates = [
170
+ [
171
+ res_1[j] + i * 2 * radius * reverse_vector_normalized[j]
172
+ for j in range(3)
173
+ ]
174
+ for i in range(4)
175
+ ]
176
+ EC_tag = [
177
+ ECNumber(number=num, coordinate=coord, color=color, radius=radius)
178
+ for num, coord, color in zip(EC_numbers, coordinates, colors)
179
+ ]
180
+ EC_tags.append(EC_tag)
181
+
182
+ EC_colored = [f":{color}[{num}]" for num, color in zip(EC_numbers, colors)]
183
  st.sidebar.write("Visualized as colored spheres: " + ".".join(EC_colored))
184
 
185
 
 
191
  threshold=min_attn,
192
  model_type=selected_model.name,
193
  top_n=n_highest_resis,
194
+ ec_numbers=EC_tags if ec_number else None,
195
  )
196
 
197
  sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
 
233
  )
234
 
235
  if selected_model.name == ModelType.ZymCTRL and ec_number:
236
+ for EC_tag in EC_tags:
237
+ for EC_num in EC_tag:
238
+ stmol.add_sphere(
239
+ xyzview,
240
+ spcenter=EC_num.coordinate,
241
+ radius=EC_num.radius,
242
+ spColor=EC_num.color,
243
+ )
244
 
245
  if label_resi:
246
  for hl_resi in hl_resi_list: