add all model variations
- dockformer/config.py +9 -3
- dockformer/data/data_pipeline.py +36 -21
- dockformer/model/heads.py +29 -11
- dockformer/model/model.py +2 -1
- dockformer/utils/loss.py +23 -0
- env_consts.py +6 -1
- inference_app.py +10 -3
- resources/run_config.json +1 -1
- run_on_seq.py +2 -2
- run_pretrained_model.py +17 -10
dockformer/config.py
CHANGED
@@ -250,6 +250,9 @@ config = mlc.ConfigDict(
                 "c_s": c_s,
                 "num_bins": aux_affinity_bins,
             },
+            "affinity_cls_reg": {
+                "c_s": c_s,
+            },
             "binding_site": {
                 "c_s": c_s,
                 "c_out": 1,
@@ -302,19 +305,22 @@ config = mlc.ConfigDict(
             "min_bin": 0,
             "max_bin": 15,
             "no_bins": aux_affinity_bins,
-            "weight": 0.
+            "weight": 0.03,
         },
         "affinity1d": {
             "min_bin": 0,
             "max_bin": 15,
             "no_bins": aux_affinity_bins,
-            "weight": 0.
+            "weight": 0.03,
         },
         "affinity_cls": {
             "min_bin": 0,
             "max_bin": 15,
             "no_bins": aux_affinity_bins,
-            "weight": 0.
+            "weight": 0.03,
         },
+        "affinity_cls_reg": {
+            "weight": 0.03,
+        },
         "fape_backbone": {
             "clamp_distance": 10.0,
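The three existing affinity heads move from weight 0. to 0.03, and the new affinity_cls_reg head joins at the same weight. A minimal sketch of how per-head "weight" entries in an mlc.ConfigDict are typically folded into a total loss; the helper and the dummy loss values below are illustrative, not the repo's actual AlphaFoldLoss:

# Illustrative sketch only: per-head "weight" entries gate auxiliary losses
# in a weighted sum; a 0. weight silences a head entirely.
import ml_collections as mlc
import torch

loss_cfg = mlc.ConfigDict({
    "affinity2d": {"weight": 0.03},
    "affinity1d": {"weight": 0.03},
    "affinity_cls": {"weight": 0.03},
    "affinity_cls_reg": {"weight": 0.03},
})

def weighted_total(losses):
    # losses: head name -> scalar tensor
    return sum(loss_cfg[name].weight * val for name, val in losses.items())

total = weighted_total({name: torch.tensor(1.0) for name in loss_cfg.keys()})
print(float(total))  # 0.12 = 4 heads * 0.03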
dockformer/data/data_pipeline.py
CHANGED
@@ -101,6 +101,18 @@ def _apply_protein_probablistic_transforms(tensors: FeatureTensorDict, cfg: mlc.
     return tensors
 
 
+def get_psuedo_beta(pdb_path: str) -> torch.Tensor:
+    """Get pseudo beta positions for a protein."""
+    with open(pdb_path, 'r') as f:
+        pdb_str = f.read()
+    protein_object = protein.from_pdb_string(pdb_str)
+    pdb_feats = make_protein_features(protein_object, "")
+    tensor_feats = _np_filter_and_to_tensor_dict(pdb_feats, ["aatype", "all_atom_positions", "all_atom_mask"])
+    pdb_feats = _apply_protein_transforms(tensor_feats)
+
+    return pdb_feats["pseudo_beta"]
+
+
 class DataPipeline:
     """Assembles input features."""
     def __init__(self, config: mlc.ConfigDict, mode: str):
@@ -200,37 +212,40 @@ class DataPipeline:
                 raise ValueError(f"Unknown key in sdf list features {k}")
         return joined_ligand_feats
 
+    @staticmethod
+    def _get_gt_positions(ref_ligand_path: str, gt_ligand_path: str):
+        ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
+        gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
+        gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
+        gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
+
+        if len(gt_positions) == 0:
+            from rdkit.Chem import rdFMCS
+            mcs_result = rdFMCS.FindMCS([ref_ligand, gt_ligand])
+            if mcs_result.canceled:
+                print("MCS search canceled, Error!!!! Can't map ref ligand to gt ligand")
+                gt_positions = gt_original_positions
+            else:
+                mcs_mol = Chem.MolFromSmarts(mcs_result.smartsString)
+                ref_match = ref_ligand.GetSubstructMatch(mcs_mol)
+                gt_match = gt_ligand.GetSubstructMatch(mcs_mol)
+                ref_to_gt_atom = {ref_idx: gt_idx for ref_idx, gt_idx in zip(ref_match, gt_match)}
+                gt_positions = [gt_original_positions[ref_to_gt_atom[i]] for i in sorted(list(ref_to_gt_atom.keys()))]
+
+        return gt_positions
+
     def get_matching_positions_list(self, ref_path_list: List[str], gt_path_list: List[str]):
         joined_gt_positions = []
 
         for ref_ligand_path, gt_ligand_path in zip(ref_path_list, gt_path_list):
-            ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
-            gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
-
-            gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
-
-            gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
+            gt_positions = self.get_matching_positions(ref_ligand_path, gt_ligand_path)
 
             joined_gt_positions.extend(gt_positions)
 
         return torch.tensor(np.array(joined_gt_positions)).float()
 
     def get_matching_positions(self, ref_ligand_path: str, gt_ligand_path: str):
-        ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
-        gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
-
-        gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
-
-        gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
-
-        # ref_positions = ref_ligand.GetConformer(0).GetPositions()
-        # for i in range(len(ref_positions)):
-        #     for j in range(i + 1, len(ref_positions)):
-        #         dist_ref = np.linalg.norm(ref_positions[i] - ref_positions[j])
-        #         dist_gt = np.linalg.norm(gt_positions[i] - gt_positions[j])
-        #         dist_gt = np.linalg.norm(gt_original_positions[i] - gt_original_positions[j])
-        #         if abs(dist_ref - dist_gt) > 1.0:
-        #             print(f"Distance mismatch {i} {j} {dist_ref} {dist_gt}")
+        gt_positions = self._get_gt_positions(ref_ligand_path, gt_ligand_path)
 
         return torch.tensor(np.array(gt_positions)).float()
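The new _get_gt_positions fallback handles reference/ground-truth pairs where a direct substructure match fails by mapping atoms through the maximum common substructure. A self-contained sketch of that RDKit pattern, using two illustrative SMILES in place of the SDF files:

# Sketch of the MCS-based atom mapping used in the fallback above.
# The molecules are illustrative stand-ins for the ref/gt SDF files.
from rdkit import Chem
from rdkit.Chem import rdFMCS

ref_ligand = Chem.MolFromSmiles("c1ccccc1O")  # phenol
gt_ligand = Chem.MolFromSmiles("c1ccccc1N")   # aniline: direct match fails

assert not gt_ligand.GetSubstructMatch(ref_ligand)  # empty tuple -> fallback

mcs_result = rdFMCS.FindMCS([ref_ligand, gt_ligand])
mcs_mol = Chem.MolFromSmarts(mcs_result.smartsString)

# Both matches list atom indices in MCS-atom order, so zipping them pairs
# each ref atom with its gt counterpart.
ref_match = ref_ligand.GetSubstructMatch(mcs_mol)
gt_match = gt_ligand.GetSubstructMatch(mcs_mol)
ref_to_gt_atom = dict(zip(ref_match, gt_match))
print(ref_to_gt_atom)  # ring-atom mapping; the O/N substituents drop out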
dockformer/model/heads.py
CHANGED
@@ -50,6 +50,10 @@ class AuxiliaryHeads(nn.Module):
             **config["affinity_cls"],
         )
 
+        self.affinity_cls_reg = AffinityClsTokenPredictorRegression(
+            **config["affinity_cls_reg"],
+        )
+
         self.binding_site = BindingSitePredictor(
             **config["binding_site"],
         )
@@ -60,7 +64,7 @@ class AuxiliaryHeads(nn.Module):
 
         self.config = config
 
-    def forward(self, outputs, inter_mask, affinity_mask):
+    def forward(self, outputs, inter_mask, affinity_mask, ligand_mask):
         aux_out = {}
         lddt_logits = self.plddt(outputs["sm"]["single"])
         aux_out["lddt_logits"] = lddt_logits
@@ -75,10 +79,12 @@ class AuxiliaryHeads(nn.Module):
 
         aux_out["affinity_2d_logits"] = self.affinity_2d(outputs["pair"], aux_out["inter_contact_logits"], inter_mask)
 
-        aux_out["affinity_1d_logits"] = self.affinity_1d(outputs["single"])
+        aux_out["affinity_1d_logits"] = self.affinity_1d(outputs["single"], ligand_mask)
 
         aux_out["affinity_cls_logits"] = self.affinity_cls(outputs["single"], affinity_mask)
 
+        aux_out["affinity_cls_reg_logits"] = self.affinity_cls_reg(outputs["single"], affinity_mask)
+
         aux_out["binding_site_logits"] = self.binding_site(outputs["single"])
 
         return aux_out
@@ -120,18 +126,14 @@ class Affinity1DPredictor(nn.Module):
         self.c_s = c_s
 
         self.linear1 = Linear(self.c_s, self.c_s, init="final")
+        self.out = Linear(self.c_s, num_bins, init="final")
 
-
-
-    def forward(self, s):
+    def forward(self, s, ligand_mask):
         # [*, N, C_out]
-        s = self.linear1(s)
-
-        # get an average over the sequence
-        s = torch.mean(s, dim=1)
+        s = nn.functional.relu(self.linear1(s))
+        mean_of_ligand = (s * ligand_mask.unsqueeze(-1)).sum(dim=1) / ligand_mask.sum(dim=1).unsqueeze(-1)
 
-
-        return logits
+        return self.out(mean_of_ligand)
 
 
 class AffinityClsTokenPredictor(nn.Module):
@@ -146,6 +148,22 @@ class AffinityClsTokenPredictor(nn.Module):
         return self.linear(affinity_tokens)
 
 
+class AffinityClsTokenPredictorRegression(nn.Module):
+    def __init__(self, c_s, **kwargs):
+        super(AffinityClsTokenPredictorRegression, self).__init__()
+
+        self.c_s = c_s
+        self.fc1 = nn.Linear(self.c_s, self.c_s)
+        self.fc2 = nn.Linear(self.c_s, self.c_s)
+        self.out = nn.Linear(self.c_s, 1)
+
+    def forward(self, s, affinity_mask):
+        affinity_tokens = (s * affinity_mask.unsqueeze(-1)).sum(dim=1)
+        x = nn.functional.relu(self.fc1(affinity_tokens))
+        x = nn.functional.relu(self.fc2(x))
+        return self.out(x)
+
+
 class BindingSitePredictor(nn.Module):
     def __init__(self, c_s, c_out, **kwargs):
         super(BindingSitePredictor, self).__init__()
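Affinity1DPredictor now pools only over ligand tokens instead of a plain torch.mean over the whole sequence. A toy sketch of that masked mean; all shapes are illustrative:

# Toy sketch of the masked mean in Affinity1DPredictor.forward: average the
# single representation over ligand tokens only.
import torch

batch, n_tokens, c_s = 2, 7, 4
s = torch.randn(batch, n_tokens, c_s)       # [*, N, C_s] single repr
ligand_mask = torch.zeros(batch, n_tokens)  # 1.0 where the token is ligand
ligand_mask[:, 4:] = 1.0                    # here: last 3 tokens are ligand

mean_of_ligand = (s * ligand_mask.unsqueeze(-1)).sum(dim=1) / ligand_mask.sum(dim=1).unsqueeze(-1)
print(mean_of_ligand.shape)  # torch.Size([2, 4]) == [batch, c_s]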
dockformer/model/model.py
CHANGED
@@ -313,6 +313,7 @@ class AlphaFold(nn.Module):
         outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)
 
         # Run auxiliary heads, remove the recycling dimension batch properties
-        outputs.update(self.aux_heads(outputs, batch["inter_pair_mask"][..., 0], batch["affinity_mask"][..., 0]))
+        outputs.update(self.aux_heads(outputs, batch["inter_pair_mask"][..., 0], batch["affinity_mask"][..., 0],
+                                      batch["ligand_mask"][..., 0]))
 
         return outputs
dockformer/utils/loss.py
CHANGED
@@ -670,6 +670,25 @@ def affinity_loss(
     # print("after factor", after_factor.shape, after_factor, affinity_loss_factor.sum(), mean_val)
     return mean_val
 
+def affinity_loss_reg(
+        logits,
+        affinity,
+        affinity_loss_factor,
+        **kwargs,
+):
+    # apply mse loss
+    errors = torch.nn.functional.mse_loss(logits, affinity, reduction='none')
+
+    # print("errors dim", errors.shape, affinity_loss_factor.shape, errors)
+    after_factor = errors * affinity_loss_factor.squeeze()
+    if affinity_loss_factor.sum() > 0.1:
+        mean_val = after_factor.sum() / affinity_loss_factor.sum()
+    else:
+        # If no affinity in batch - get a very small loss. the factor also makes the loss small
+        mean_val = after_factor.sum() * 1e-3
+    # print("after factor", after_factor.shape, after_factor, affinity_loss_factor.sum(), mean_val)
+    return mean_val
+
 
 def positions_inter_distogram_loss(
     out,
@@ -1085,6 +1104,10 @@ class AlphaFoldLoss(nn.Module):
                 logits=out["affinity_cls_logits"],
                 **{**batch, **self.config.affinity_cls},
             ),
+            "affinity_cls_reg": lambda: affinity_loss_reg(
+                logits=out["affinity_cls_reg_logits"],
+                **{**batch, **self.config.affinity_cls_reg},
+            ),
             "binding_site": lambda: binding_site_loss(
                 logits=out["binding_site_logits"],
                 **{**batch, **self.config.binding_site},
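The key behavior in affinity_loss_reg is the gating: per-example squared errors are multiplied by affinity_loss_factor, so examples without a measured affinity contribute (almost) nothing. A toy sketch with illustrative numbers:

# Toy sketch of affinity_loss_reg's gating behavior.
import torch

logits = torch.tensor([[5.2], [7.9]])
affinity = torch.tensor([[5.0], [8.0]])
affinity_loss_factor = torch.tensor([[1.0], [0.0]])  # no label for example 2

errors = torch.nn.functional.mse_loss(logits, affinity, reduction="none")
after_factor = errors.squeeze() * affinity_loss_factor.squeeze()
if affinity_loss_factor.sum() > 0.1:
    mean_val = after_factor.sum() / affinity_loss_factor.sum()
else:
    mean_val = after_factor.sum() * 1e-3  # keep a tiny, well-defined loss
print(float(mean_val))  # ~0.04: only the labeled example counts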
env_consts.py
CHANGED
@@ -3,7 +3,12 @@ import os
 TEST_INPUT_DIR = None
 TEST_OUTPUT_DIR = None
 THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__))
-
+MODEL_NAME_TO_CKPT = {
+    "DockFormer-Screen": os.path.join(THIS_FILE_DIR, "resources", "dockformer_screen_102-110250.ckpt"),
+    "DockFormer-PDBBind": os.path.join(THIS_FILE_DIR, "resources", "dockformer_pdbbind_95-108500.ckpt"),
+    "DockFormer-PLINDER": os.path.join(THIS_FILE_DIR, "resources", "dockformer_plinder_132-98000.ckpt"),
+}
+
 RUN_CONFIG_PATH = os.path.join(THIS_FILE_DIR, "resources", "run_config.json")
 
 OUTPUT_PROT_PATH = os.path.join(THIS_FILE_DIR, "predicted_protein_out.pdb")
inference_app.py
CHANGED
@@ -4,16 +4,17 @@ import gradio as gr
 
 from gradio_molecule3d import Molecule3D
 from run_on_seq import run_on_sample_seqs
-from env_consts import RUN_CONFIG_PATH, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH
+from env_consts import RUN_CONFIG_PATH, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH, MODEL_NAME_TO_CKPT
 
 
-def predict(input_sequence, input_ligand, input_msa, input_protein):
+def predict(input_sequence, input_ligand, input_msa, input_protein, model_variation):
     start_time = time.time()
     # Do inference here
     # return an output pdb file with the protein and ligand with resname LIG or UNK.
     # also return any metrics you want to log, metrics will not be used for evaluation but might be useful for users
+    ckpt_path = MODEL_NAME_TO_CKPT[model_variation]
     metrics = run_on_sample_seqs(input_sequence, input_protein, input_ligand, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH,
-                                 RUN_CONFIG_PATH)
+                                 RUN_CONFIG_PATH, ckpt_path)
     end_time = time.time()
     run_time = end_time - start_time
 
@@ -23,6 +24,12 @@ def predict(input_sequence, input_ligand, input_msa, input_protein):
 with gr.Blocks() as app:
     gr.Markdown("DockFormer")
 
+    model_variation = gr.Dropdown(
+        choices=["DockFormer-Screen", "DockFormer-PDBBind", "DockFormer-PLINDER"],
+        label="Select model variation",
+        value="DockFormer-Screen"  # Default value
+    )
+
     # gr.Markdown("Title, description, and other information about the model")
     with gr.Row():
         input_sequence = gr.Textbox(lines=3, label="Input Protein sequence (FASTA)")
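The dropdown's selected string reaches predict() as model_variation and is looked up in MODEL_NAME_TO_CKPT. A stripped-down sketch of that wiring; the component layout and checkpoint path below are illustrative, not the Space's full app:

# Stripped-down sketch of how a gr.Dropdown value flows into predict() as a
# plain string; the real Space wires more inputs (sequence, ligand, etc.).
import gradio as gr

MODEL_NAME_TO_CKPT = {"DockFormer-Screen": "screen.ckpt"}  # illustrative path

def predict(model_variation):
    ckpt_path = MODEL_NAME_TO_CKPT[model_variation]
    return f"would run inference with {ckpt_path}"

with gr.Blocks() as demo:
    model_variation = gr.Dropdown(
        choices=list(MODEL_NAME_TO_CKPT),
        value="DockFormer-Screen",
        label="Select model variation",
    )
    result = gr.Textbox(label="Result")
    gr.Button("Run").click(predict, inputs=[model_variation], outputs=[result])

# demo.launch()  # uncomment to serve locally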
resources/run_config.json
CHANGED
@@ -16,6 +16,6 @@
     "affinity_cls": {"weight": 0.03},
     "fape_interface": {"weight": 1.0}
   },
-  "globals": {"max_lr": 0.
+  "globals": {"max_lr": 0.0002}
 }
 }
run_on_seq.py
CHANGED
@@ -115,7 +115,7 @@ def create_embeded_molecule(ref_mol: Chem.Mol, smiles: str):
 
 
 def run_on_sample_seqs(seq_protein: str, template_protein_path: str, smiles: str, output_prot_path: str,
-                       output_lig_path: str, run_config_path: str):
+                       output_lig_path: str, run_config_path: str, ckpt_path: str):
     temp_dir = tempfile.TemporaryDirectory()
     temp_dir_path = temp_dir.name
     metrics = {}
@@ -132,7 +132,7 @@ def run_on_sample_seqs(seq_protein: str, template_protein_path: str, smiles: str
     json.dump(json_data, open(f"{tmp_json_folder}/input.json", "w"))
     tmp_output_folder = f"{temp_dir_path}/output"
 
-    run_on_folder(tmp_json_folder, tmp_output_folder, run_config_path, skip_relaxation=True,
+    run_on_folder(tmp_json_folder, tmp_output_folder, run_config_path, ckpt_path, skip_relaxation=True,
                   long_sequence_inference=False, skip_exists=False)
     predicted_protein_path = tmp_output_folder + "/predictions/input_predicted_protein.pdb"
     predicted_ligand_path = tmp_output_folder + "/predictions/input_predicted_ligand_0.sdf"
run_pretrained_model.py
CHANGED
@@ -14,7 +14,7 @@
 # limitations under the License.
 import sys
 
-from env_consts import TEST_INPUT_DIR, TEST_OUTPUT_DIR
+from env_consts import TEST_INPUT_DIR, TEST_OUTPUT_DIR
 import json
 import logging
 import numpy as np
@@ -59,7 +59,7 @@ def override_config(base_config, overriding_config):
     return base_config
 
 
-def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_relaxation=True,
+def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, ckpt_path: str, skip_relaxation=True,
                   long_sequence_inference=False, skip_exists=False):
    config_preset = "initial_training"
    save_outputs = False
@@ -67,9 +67,7 @@ def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_re
 
     run_config = json.load(open(run_config_path))
 
-    ckpt_path =
-    if ckpt_path is None:
-        ckpt_path = get_latest_checkpoint(os.path.join(run_config["train_output_dir"], "checkpoint"))
+    ckpt_path = os.path.abspath(ckpt_path)
     print("Using checkpoint: ", ckpt_path)
 
     config = model_config(config_preset, long_sequence_inference=long_sequence_inference)
@@ -115,17 +113,29 @@ def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_re
                            dim=-1).item()
     affinity_cls = torch.sum(torch.softmax(torch.tensor(out["affinity_cls_logits"]), -1) * torch.linspace(0, 15, 32),
                              dim=-1).item()
-
+    affinity_cls_reg = torch.tensor(out["affinity_cls_reg_logits"]).item()
 
     affinity_2d_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_2d_logits"]))].item()
     affinity_1d_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_1d_logits"]))].item()
     affinity_cls_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_cls_logits"]))].item()
 
+    protein_mask = processed_feature_dict["protein_mask"][0].astype(bool)
+    ligand_mask = processed_feature_dict["ligand_mask"][0].astype(bool)
+
+    protein_length = protein_mask.sum()
+    ligand_length = ligand_mask.sum()
+    predicted_inter_contacts_logits = torch.tensor(out["inter_contact_logits"][0][:protein_length,
+                                                   protein_length:protein_length+ligand_length, :])
+
+    top_100_inter_contacts = torch.topk(predicted_inter_contacts_logits.flatten(), 100).indices
+    inter_contacts_indices = [[int(i // ligand_length), int(i % ligand_length)] for i in top_100_inter_contacts]
+
     print("Affinity: ", affinity_2d, affinity_cls, affinity_1d)
     with open(affinity_output_path, "w") as f:
         json.dump({"affinity_2d": affinity_2d, "affinity_1d": affinity_1d, "affinity_cls": affinity_cls,
                    "affinity_2d_max": affinity_2d_max, "affinity_1d_max": affinity_1d_max,
-                   "affinity_cls_max": affinity_cls_max
+                   "affinity_cls_max": affinity_cls_max, "affinity_cls_reg": affinity_cls_reg,
+                   "inter_contacts": inter_contacts_indices}, f)
 
     # binding_site = torch.sigmoid(torch.tensor(out["binding_site_logits"])) * 100
     # binding_site = binding_site[:processed_feature_dict["aatype"].shape[1]].flatten()
@@ -135,9 +145,6 @@ def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_re
 
     ligand_output_path = os.path.join(output_directory, f"{output_name}_ligand_{{i}}.sdf")
 
-    protein_mask = processed_feature_dict["protein_mask"][0].astype(bool)
-    ligand_mask = processed_feature_dict["ligand_mask"][0].astype(bool)
-
     save_output_structure(
         aatype=processed_feature_dict["aatype"][0][protein_mask],
         residue_index=processed_feature_dict["in_chain_residue_index"][0],
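The new "inter_contacts" entry recovers 2-D (protein residue, ligand atom) indices from a flattened top-k over the contact-logit block. A toy version of that index arithmetic, with a random matrix standing in for the real logits:

# Toy sketch of the flattened top-k -> (protein, ligand) index recovery.
import torch

protein_length, ligand_length = 4, 3
contact_logits = torch.randn(protein_length, ligand_length)

top = torch.topk(contact_logits.flatten(), k=5).indices
# flatten() is row-major, so // and % by the row width invert it exactly
pairs = [[int(i // ligand_length), int(i % ligand_length)] for i in top]
print(pairs)  # e.g. [[2, 1], [0, 0], ...]: [residue_index, ligand_atom_index]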