Commit f1046e9 by bshor
Parent: c271a8f

add all model variations

dockformer/config.py CHANGED
@@ -250,6 +250,9 @@ config = mlc.ConfigDict(
             "c_s": c_s,
             "num_bins": aux_affinity_bins,
         },
+        "affinity_cls_reg": {
+            "c_s": c_s,
+        },
         "binding_site": {
             "c_s": c_s,
             "c_out": 1,
@@ -302,19 +305,22 @@ config = mlc.ConfigDict(
             "min_bin": 0,
             "max_bin": 15,
             "no_bins": aux_affinity_bins,
-            "weight": 0.0,
+            "weight": 0.03,
         },
         "affinity1d": {
             "min_bin": 0,
             "max_bin": 15,
             "no_bins": aux_affinity_bins,
-            "weight": 0.0,
+            "weight": 0.03,
         },
         "affinity_cls": {
            "min_bin": 0,
            "max_bin": 15,
            "no_bins": aux_affinity_bins,
-            "weight": 0.0,
+            "weight": 0.03,
+        },
+        "affinity_cls_reg": {
+            "weight": 0.03,
         },
         "fape_backbone": {
             "clamp_distance": 10.0,
dockformer/data/data_pipeline.py CHANGED
@@ -101,6 +101,18 @@ def _apply_protein_probablistic_transforms(tensors: FeatureTensorDict, cfg: mlc.
     return tensors


+def get_psuedo_beta(pdb_path: str) -> torch.Tensor:
+    """Get pseudo beta positions for a protein."""
+    with open(pdb_path, 'r') as f:
+        pdb_str = f.read()
+    protein_object = protein.from_pdb_string(pdb_str)
+    pdb_feats = make_protein_features(protein_object, "")
+    tensor_feats = _np_filter_and_to_tensor_dict(pdb_feats, ["aatype", "all_atom_positions", "all_atom_mask"])
+    pdb_feats = _apply_protein_transforms(tensor_feats)
+
+    return pdb_feats["pseudo_beta"]
+
+
 class DataPipeline:
     """Assembles input features."""
     def __init__(self, config: mlc.ConfigDict, mode: str):
@@ -200,37 +212,40 @@ class DataPipeline:
                 raise ValueError(f"Unknown key in sdf list features {k}")
         return joined_ligand_feats

+    @staticmethod
+    def _get_gt_positions(ref_ligand_path: str, gt_ligand_path: str):
+        ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
+        gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
+        gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
+        gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
+
+        if len(gt_positions) == 0:
+            from rdkit.Chem import rdFMCS
+            mcs_result = rdFMCS.FindMCS([ref_ligand, gt_ligand])
+            if mcs_result.canceled:
+                print("MCS search canceled, Error!!!! Can't map ref ligand to gt ligand")
+                gt_positions = gt_original_positions
+            else:
+                mcs_mol = Chem.MolFromSmarts(mcs_result.smartsString)
+                ref_match = ref_ligand.GetSubstructMatch(mcs_mol)
+                gt_match = gt_ligand.GetSubstructMatch(mcs_mol)
+                ref_to_gt_atom = {ref_idx: gt_idx for ref_idx, gt_idx in zip(ref_match, gt_match)}
+                gt_positions = [gt_original_positions[ref_to_gt_atom[i]] for i in sorted(list(ref_to_gt_atom.keys()))]
+
+        return gt_positions
+
     def get_matching_positions_list(self, ref_path_list: List[str], gt_path_list: List[str]):
         joined_gt_positions = []

         for ref_ligand_path, gt_ligand_path in zip(ref_path_list, gt_path_list):
-            ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
-            gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
-
-            gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
-
-            gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
+            gt_positions = self.get_matching_positions(ref_ligand_path, gt_ligand_path)

             joined_gt_positions.extend(gt_positions)

         return torch.tensor(np.array(joined_gt_positions)).float()

     def get_matching_positions(self, ref_ligand_path: str, gt_ligand_path: str):
-        ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
-        gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
-
-        gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
-
-        gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
-
-        # ref_positions = ref_ligand.GetConformer(0).GetPositions()
-        # for i in range(len(ref_positions)):
-        #     for j in range(i + 1, len(ref_positions)):
-        #         dist_ref = np.linalg.norm(ref_positions[i] - ref_positions[j])
-        #         dist_gt = np.linalg.norm(gt_positions[i] - gt_positions[j])
-        #         dist_gt = np.linalg.norm(gt_original_positions[i] - gt_original_positions[j])
-        #         if abs(dist_ref - dist_gt) > 1.0:
-        #             print(f"Distance mismatch {i} {j} {dist_ref} {dist_gt}")
+        gt_positions = self._get_gt_positions(ref_ligand_path, gt_ligand_path)

         return torch.tensor(np.array(gt_positions)).float()
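
Note (illustrative sketch, not part of this commit): the fallback in _get_gt_positions kicks in when the reference ligand is not an exact substructure of the ground-truth ligand; rdFMCS then finds the maximum common substructure, and the two matches are zipped into an atom map. A self-contained sketch of the same RDKit calls, with made-up SMILES:

    from rdkit import Chem
    from rdkit.Chem import rdFMCS

    ref = Chem.MolFromSmiles("c1ccccc1CC(=O)O")  # reference ligand (illustrative)
    gt = Chem.MolFromSmiles("c1ccccc1CC(=O)N")   # ground-truth ligand, differs at one atom

    # The direct substructure match fails because the molecules are not identical...
    assert len(gt.GetSubstructMatch(ref)) == 0

    # ...so fall back to the maximum common substructure to build an atom map.
    mcs = rdFMCS.FindMCS([ref, gt])
    mcs_mol = Chem.MolFromSmarts(mcs.smartsString)
    ref_match = ref.GetSubstructMatch(mcs_mol)  # i-th MCS atom -> ref atom index
    gt_match = gt.GetSubstructMatch(mcs_mol)    # i-th MCS atom -> gt atom index
    ref_to_gt = dict(zip(ref_match, gt_match))
    print(ref_to_gt)                            # ref atom index -> gt atom index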
dockformer/model/heads.py CHANGED
@@ -50,6 +50,10 @@ class AuxiliaryHeads(nn.Module):
            **config["affinity_cls"],
        )

+        self.affinity_cls_reg = AffinityClsTokenPredictorRegression(
+            **config["affinity_cls_reg"],
+        )
+
        self.binding_site = BindingSitePredictor(
            **config["binding_site"],
        )
@@ -60,7 +64,7 @@

        self.config = config

-    def forward(self, outputs, inter_mask, affinity_mask):
+    def forward(self, outputs, inter_mask, affinity_mask, ligand_mask):
        aux_out = {}
        lddt_logits = self.plddt(outputs["sm"]["single"])
        aux_out["lddt_logits"] = lddt_logits
@@ -75,10 +79,12 @@

        aux_out["affinity_2d_logits"] = self.affinity_2d(outputs["pair"], aux_out["inter_contact_logits"], inter_mask)

-        aux_out["affinity_1d_logits"] = self.affinity_1d(outputs["single"])
+        aux_out["affinity_1d_logits"] = self.affinity_1d(outputs["single"], ligand_mask)

        aux_out["affinity_cls_logits"] = self.affinity_cls(outputs["single"], affinity_mask)

+        aux_out["affinity_cls_reg_logits"] = self.affinity_cls_reg(outputs["single"], affinity_mask)
+
        aux_out["binding_site_logits"] = self.binding_site(outputs["single"])

        return aux_out
@@ -120,18 +126,14 @@ class Affinity1DPredictor(nn.Module):
        self.c_s = c_s

        self.linear1 = Linear(self.c_s, self.c_s, init="final")
+        self.out = Linear(self.c_s, num_bins, init="final")

-        self.linear2 = Linear(self.c_s, num_bins, init="final")
-
-    def forward(self, s):
+    def forward(self, s, ligand_mask):
        # [*, N, C_out]
-        s = self.linear1(s)
-
-        # get an average over the sequence
-        s = torch.mean(s, dim=1)
+        s = nn.functional.relu(self.linear1(s))
+        mean_of_ligand = (s * ligand_mask.unsqueeze(-1)).sum(dim=1) / ligand_mask.sum(dim=1).unsqueeze(-1)

-        logits = self.linear2(s)
-        return logits
+        return self.out(mean_of_ligand)


 class AffinityClsTokenPredictor(nn.Module):
@@ -146,6 +148,22 @@ class AffinityClsTokenPredictor(nn.Module):
        return self.linear(affinity_tokens)


+class AffinityClsTokenPredictorRegression(nn.Module):
+    def __init__(self, c_s, **kwargs):
+        super(AffinityClsTokenPredictorRegression, self).__init__()
+
+        self.c_s = c_s
+        self.fc1 = nn.Linear(self.c_s, self.c_s)
+        self.fc2 = nn.Linear(self.c_s, self.c_s)
+        self.out = nn.Linear(self.c_s, 1)
+
+    def forward(self, s, affinity_mask):
+        affinity_tokens = (s * affinity_mask.unsqueeze(-1)).sum(dim=1)
+        x = nn.functional.relu(self.fc1(affinity_tokens))
+        x = nn.functional.relu(self.fc2(x))
+        return self.out(x)
+
+
 class BindingSitePredictor(nn.Module):
    def __init__(self, c_s, c_out, **kwargs):
        super(BindingSitePredictor, self).__init__()
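
Note (illustrative sketch, not part of this commit): Affinity1DPredictor now pools the single representation over ligand tokens only, instead of a plain mean over the whole sequence. The masked-mean idiom can be checked standalone with toy shapes:

    import torch

    batch, n_tokens, c_s = 2, 5, 4
    s = torch.randn(batch, n_tokens, c_s)               # [*, N, C_s] single representation
    ligand_mask = torch.tensor([[0., 0., 1., 1., 1.],
                                [0., 1., 1., 0., 0.]])  # 1.0 on ligand tokens

    # Zero out non-ligand tokens, then divide by the ligand token count per sample.
    mean_of_ligand = (s * ligand_mask.unsqueeze(-1)).sum(dim=1) / ligand_mask.sum(dim=1).unsqueeze(-1)
    assert mean_of_ligand.shape == (batch, c_s)
    assert torch.allclose(mean_of_ligand[0], s[0, 2:].mean(dim=0))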
dockformer/model/model.py CHANGED
@@ -313,6 +313,7 @@ class AlphaFold(nn.Module):
        outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)

        # Run auxiliary heads, remove the recycling dimension batch properties
-        outputs.update(self.aux_heads(outputs, batch["inter_pair_mask"][..., 0], batch["affinity_mask"][..., 0]))
+        outputs.update(self.aux_heads(outputs, batch["inter_pair_mask"][..., 0], batch["affinity_mask"][..., 0],
+                                      batch["ligand_mask"][..., 0]))

        return outputs
dockformer/utils/loss.py CHANGED
@@ -670,6 +670,25 @@ def affinity_loss(
    # print("after factor", after_factor.shape, after_factor, affinity_loss_factor.sum(), mean_val)
    return mean_val

+def affinity_loss_reg(
+    logits,
+    affinity,
+    affinity_loss_factor,
+    **kwargs,
+):
+    # apply mse loss
+    errors = torch.nn.functional.mse_loss(logits, affinity, reduction='none')
+
+    # print("errors dim", errors.shape, affinity_loss_factor.shape, errors)
+    after_factor = errors * affinity_loss_factor.squeeze()
+    if affinity_loss_factor.sum() > 0.1:
+        mean_val = after_factor.sum() / affinity_loss_factor.sum()
+    else:
+        # If no affinity in batch - get a very small loss. the factor also makes the loss small
+        mean_val = after_factor.sum() * 1e-3
+    # print("after factor", after_factor.shape, after_factor, affinity_loss_factor.sum(), mean_val)
+    return mean_val
+

 def positions_inter_distogram_loss(
    out,
@@ -1085,6 +1104,10 @@ class AlphaFoldLoss(nn.Module):
                logits=out["affinity_cls_logits"],
                **{**batch, **self.config.affinity_cls},
            ),
+            "affinity_cls_reg": lambda: affinity_loss_reg(
+                logits=out["affinity_cls_reg_logits"],
+                **{**batch, **self.config.affinity_cls_reg},
+            ),
            "binding_site": lambda: binding_site_loss(
                logits=out["binding_site_logits"],
                **{**batch, **self.config.binding_site},
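
Note (illustrative sketch, not part of this commit): affinity_loss_reg is a masked MSE; affinity_loss_factor zeroes out samples without an affinity label, and the 0.1 threshold guards against dividing by an all-zero mask. A toy check with made-up values:

    import torch

    preds = torch.tensor([5.0, 7.0, 3.0])     # regression head outputs, squeezed
    affinity = torch.tensor([6.0, 0.0, 3.5])  # labels; sample 1 has none
    factor = torch.tensor([1.0, 0.0, 1.0])    # 0.0 masks the unlabeled sample

    errors = torch.nn.functional.mse_loss(preds, affinity, reduction='none')
    after_factor = errors * factor
    loss = after_factor.sum() / factor.sum()  # mean over labeled samples only
    print(loss)  # ((5-6)**2 + (3-3.5)**2) / 2 = 0.625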
env_consts.py CHANGED
@@ -3,7 +3,12 @@ import os
 TEST_INPUT_DIR = None
 TEST_OUTPUT_DIR = None
 THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__))
-CKPT_PATH = os.path.join(THIS_FILE_DIR, "resources", "only_weights_114-191750.ckpt")
+MODEL_NAME_TO_CKPT = {
+    "DockFormer-Screen": os.path.join(THIS_FILE_DIR, "resources", "dockformer_screen_102-110250.ckpt"),
+    "DockFormer-PDBBind": os.path.join(THIS_FILE_DIR, "resources", "dockformer_pdbbind_95-108500.ckpt"),
+    "DockFormer-PLINDER": os.path.join(THIS_FILE_DIR, "resources", "dockformer_plinder_132-98000.ckpt"),
+}
+
 RUN_CONFIG_PATH = os.path.join(THIS_FILE_DIR, "resources", "run_config.json")

 OUTPUT_PROT_PATH = os.path.join(THIS_FILE_DIR, "predicted_protein_out.pdb")
inference_app.py CHANGED
@@ -4,16 +4,17 @@ import gradio as gr

 from gradio_molecule3d import Molecule3D
 from run_on_seq import run_on_sample_seqs
-from env_consts import RUN_CONFIG_PATH, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH
+from env_consts import RUN_CONFIG_PATH, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH, MODEL_NAME_TO_CKPT


-def predict(input_sequence, input_ligand, input_msa, input_protein):
+def predict(input_sequence, input_ligand, input_msa, input_protein, model_variation):
    start_time = time.time()
    # Do inference here
    # return an output pdb file with the protein and ligand with resname LIG or UNK.
    # also return any metrics you want to log, metrics will not be used for evaluation but might be useful for users
+    ckpt_path = MODEL_NAME_TO_CKPT[model_variation]
    metrics = run_on_sample_seqs(input_sequence, input_protein, input_ligand, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH,
-                                 RUN_CONFIG_PATH)
+                                 RUN_CONFIG_PATH, ckpt_path)
    end_time = time.time()
    run_time = end_time - start_time

@@ -23,6 +24,12 @@ def predict(input_sequence, input_ligand, input_msa, input_protein):
 with gr.Blocks() as app:
    gr.Markdown("DockFormer")

+    model_variation = gr.Dropdown(
+        choices=["DockFormer-Screen", "DockFormer-PDBBind", "DockFormer-PLINDER"],
+        label="Select model variation",
+        value="DockFormer-Screen"  # Default value
+    )
+
    # gr.Markdown("Title, description, and other information about the model")
    with gr.Row():
        input_sequence = gr.Textbox(lines=3, label="Input Protein sequence (FASTA)")
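
Note (illustrative sketch, not part of this commit): the diff does not show how the new dropdown reaches predict; in a typical Gradio app the component is simply appended to the click handler's inputs, along these hypothetical lines (the button and output components are made up):

    # Hypothetical wiring: the dropdown value arrives as predict's fifth argument.
    btn = gr.Button("Run")
    btn.click(
        fn=predict,
        inputs=[input_sequence, input_ligand, input_msa, input_protein, model_variation],
        outputs=[out_structure, out_metrics],  # illustrative output components
    )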
resources/run_config.json CHANGED
@@ -16,6 +16,6 @@
    "affinity_cls": {"weight": 0.03},
    "fape_interface": {"weight": 1.0}
    },
-    "globals": {"max_lr": 0.0001}
+    "globals": {"max_lr": 0.0002}
  }
 }
run_on_seq.py CHANGED
@@ -115,7 +115,7 @@ def create_embeded_molecule(ref_mol: Chem.Mol, smiles: str):


 def run_on_sample_seqs(seq_protein: str, template_protein_path: str, smiles: str, output_prot_path: str,
-                       output_lig_path: str, run_config_path: str):
+                       output_lig_path: str, run_config_path: str, ckpt_path: str):
    temp_dir = tempfile.TemporaryDirectory()
    temp_dir_path = temp_dir.name
    metrics = {}
@@ -132,7 +132,7 @@ def run_on_sample_seqs(seq_protein: str, template_protein_path: str, smiles: str
    json.dump(json_data, open(f"{tmp_json_folder}/input.json", "w"))
    tmp_output_folder = f"{temp_dir_path}/output"

-    run_on_folder(tmp_json_folder, tmp_output_folder, run_config_path, skip_relaxation=True,
+    run_on_folder(tmp_json_folder, tmp_output_folder, run_config_path, ckpt_path, skip_relaxation=True,
                  long_sequence_inference=False, skip_exists=False)
    predicted_protein_path = tmp_output_folder + "/predictions/input_predicted_protein.pdb"
    predicted_ligand_path = tmp_output_folder + "/predictions/input_predicted_ligand_0.sdf"
run_pretrained_model.py CHANGED
@@ -14,7 +14,7 @@
 # limitations under the License.
 import sys

-from env_consts import TEST_INPUT_DIR, TEST_OUTPUT_DIR, CKPT_PATH
+from env_consts import TEST_INPUT_DIR, TEST_OUTPUT_DIR
 import json
 import logging
 import numpy as np
@@ -59,7 +59,7 @@ def override_config(base_config, overriding_config):
    return base_config


-def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_relaxation=True,
+def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, ckpt_path: str, skip_relaxation=True,
                  long_sequence_inference=False, skip_exists=False):
    config_preset = "initial_training"
    save_outputs = False
@@ -67,9 +67,7 @@ def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_re

    run_config = json.load(open(run_config_path))

-    ckpt_path = CKPT_PATH
-    if ckpt_path is None:
-        ckpt_path = get_latest_checkpoint(os.path.join(run_config["train_output_dir"], "checkpoint"))
+    ckpt_path = os.path.abspath(ckpt_path)
    print("Using checkpoint: ", ckpt_path)

    config = model_config(config_preset, long_sequence_inference=long_sequence_inference)
@@ -115,17 +113,29 @@ def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_re
                             dim=-1).item()
    affinity_cls = torch.sum(torch.softmax(torch.tensor(out["affinity_cls_logits"]), -1) * torch.linspace(0, 15, 32),
                             dim=-1).item()
-
+    affinity_cls_reg = torch.tensor(out["affinity_cls_reg_logits"]).item()

    affinity_2d_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_2d_logits"]))].item()
    affinity_1d_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_1d_logits"]))].item()
    affinity_cls_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_cls_logits"]))].item()

+    protein_mask = processed_feature_dict["protein_mask"][0].astype(bool)
+    ligand_mask = processed_feature_dict["ligand_mask"][0].astype(bool)
+
+    protein_length = protein_mask.sum()
+    ligand_length = ligand_mask.sum()
+    predicted_inter_contacts_logits = torch.tensor(out["inter_contact_logits"][0][:protein_length,
+                                                   protein_length:protein_length+ligand_length, :])
+
+    top_100_inter_contacts = torch.topk(predicted_inter_contacts_logits.flatten(), 100).indices
+    inter_contacts_indices = [[int(i // ligand_length), int(i % ligand_length)] for i in top_100_inter_contacts]
+
    print("Affinity: ", affinity_2d, affinity_cls, affinity_1d)
    with open(affinity_output_path, "w") as f:
        json.dump({"affinity_2d": affinity_2d, "affinity_1d": affinity_1d, "affinity_cls": affinity_cls,
                   "affinity_2d_max": affinity_2d_max, "affinity_1d_max": affinity_1d_max,
-                   "affinity_cls_max": affinity_cls_max}, f)
+                   "affinity_cls_max": affinity_cls_max, "affinity_cls_reg": affinity_cls_reg,
+                   "inter_contacts": inter_contacts_indices}, f)

    # binding_site = torch.sigmoid(torch.tensor(out["binding_site_logits"])) * 100
    # binding_site = binding_site[:processed_feature_dict["aatype"].shape[1]].flatten()
@@ -135,9 +145,6 @@ def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_re

    ligand_output_path = os.path.join(output_directory, f"{output_name}_ligand_{{i}}.sdf")

-    protein_mask = processed_feature_dict["protein_mask"][0].astype(bool)
-    ligand_mask = processed_feature_dict["ligand_mask"][0].astype(bool)
-
    save_output_structure(
        aatype=processed_feature_dict["aatype"][0][protein_mask],
        residue_index=processed_feature_dict["in_chain_residue_index"][0],
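
Note (illustrative sketch, not part of this commit): the script decodes each binned affinity head as an expectation over 32 bin centers spanning 0 to 15 (and also reports the argmax bin), while the new contact export recovers (protein, ligand) index pairs from the flattened logits matrix. A toy version of both decodings, with made-up shapes:

    import torch

    # Expected-value decoding of binned affinity logits.
    logits = torch.randn(32)
    bin_centers = torch.linspace(0, 15, 32)
    affinity = torch.sum(torch.softmax(logits, -1) * bin_centers).item()  # mean
    affinity_max = bin_centers[torch.argmax(logits)].item()               # mode

    # Top-k contact decoding: flattening is row-major, so index // width is the
    # protein index and index % width is the ligand index.
    protein_length, ligand_length = 6, 4
    contact_logits = torch.randn(protein_length, ligand_length)
    top = torch.topk(contact_logits.flatten(), 5).indices
    pairs = [[int(i // ligand_length), int(i % ligand_length)] for i in top]
    assert all(contact_logits[p, l] == contact_logits.flatten()[i]
               for (p, l), i in zip(pairs, top))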