Spaces:
Running
on
Zero
Running
on
Zero
Update utils.py
Browse files
utils.py
CHANGED
@@ -129,7 +129,8 @@ class ReactionPredictionModel():
|
|
129 |
token = os.environ.get("TOKEN")
|
130 |
)
|
131 |
self.load_forward_model(candidate_models[model])
|
132 |
-
|
|
|
133 |
string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
|
134 |
string_template = json.load(open(string_template_path, 'r'))
|
135 |
reactant_start_str = string_template['REACTANTS_START_STRING']
|
@@ -205,33 +206,9 @@ class ReactionPredictionModel():
|
|
205 |
)
|
206 |
self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
|
207 |
self.forward_model.to("cuda")
|
208 |
-
|
209 |
-
@spaces.GPU(duration=20)
|
210 |
-
def predict_single_smiles(self, smiles, task_type):
|
211 |
-
if task_type == "full_retro":
|
212 |
-
if "." in smiles:
|
213 |
-
return None
|
214 |
-
|
215 |
-
task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"
|
216 |
-
# canonicalize the smiles
|
217 |
-
mol = Chem.MolFromSmiles(smiles)
|
218 |
-
if mol is None:
|
219 |
-
return None
|
220 |
-
smiles = Chem.MolToSmiles(mol)
|
221 |
-
|
222 |
-
smiles_list = [smiles]
|
223 |
-
task_type_list = [task_type]
|
224 |
-
|
225 |
-
|
226 |
-
df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
|
227 |
-
test_dataset = Dataset.from_pandas(df)
|
228 |
-
# construct the dataloader
|
229 |
-
test_loader = torch.utils.data.DataLoader(
|
230 |
-
test_dataset,
|
231 |
-
batch_size=1,
|
232 |
-
collate_fn=self.data_collator,
|
233 |
-
)
|
234 |
|
|
|
|
|
235 |
predictions = []
|
236 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
237 |
with torch.no_grad():
|
@@ -276,6 +253,36 @@ class ReactionPredictionModel():
|
|
276 |
predictions.append(canonized_smiles_list)
|
277 |
|
278 |
rank, invalid_rate = compute_rank(predictions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
return rank
|
280 |
|
281 |
|
|
|
129 |
token = os.environ.get("TOKEN")
|
130 |
)
|
131 |
self.load_forward_model(candidate_models[model])
|
132 |
+
|
133 |
+
print(self.forward_model.device, self.retro_model.device)
|
134 |
string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
|
135 |
string_template = json.load(open(string_template_path, 'r'))
|
136 |
reactant_start_str = string_template['REACTANTS_START_STRING']
|
|
|
206 |
)
|
207 |
self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
|
208 |
self.forward_model.to("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
+
@spaces.GPU(duration=20)
|
211 |
+
def predict(self, test_loader):
|
212 |
predictions = []
|
213 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
214 |
with torch.no_grad():
|
|
|
253 |
predictions.append(canonized_smiles_list)
|
254 |
|
255 |
rank, invalid_rate = compute_rank(predictions)
|
256 |
+
|
257 |
+
return rank
|
258 |
+
|
259 |
+
def predict_single_smiles(self, smiles, task_type):
|
260 |
+
if task_type == "full_retro":
|
261 |
+
if "." in smiles:
|
262 |
+
return None
|
263 |
+
|
264 |
+
task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"
|
265 |
+
# canonicalize the smiles
|
266 |
+
mol = Chem.MolFromSmiles(smiles)
|
267 |
+
if mol is None:
|
268 |
+
return None
|
269 |
+
smiles = Chem.MolToSmiles(mol)
|
270 |
+
|
271 |
+
smiles_list = [smiles]
|
272 |
+
task_type_list = [task_type]
|
273 |
+
|
274 |
+
|
275 |
+
df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
|
276 |
+
test_dataset = Dataset.from_pandas(df)
|
277 |
+
# construct the dataloader
|
278 |
+
test_loader = torch.utils.data.DataLoader(
|
279 |
+
test_dataset,
|
280 |
+
batch_size=1,
|
281 |
+
collate_fn=self.data_collator,
|
282 |
+
)
|
283 |
+
|
284 |
+
rank = self.predict(test_loader)
|
285 |
+
|
286 |
return rank
|
287 |
|
288 |
|