feiyang-cai committed on
Commit
cd388a2
·
verified ·
1 Parent(s): 11b6a1c

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +34 -27
utils.py CHANGED
@@ -129,7 +129,8 @@ class ReactionPredictionModel():
129
  token = os.environ.get("TOKEN")
130
  )
131
  self.load_forward_model(candidate_models[model])
132
-
 
133
  string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
134
  string_template = json.load(open(string_template_path, 'r'))
135
  reactant_start_str = string_template['REACTANTS_START_STRING']
@@ -205,33 +206,9 @@ class ReactionPredictionModel():
205
  )
206
  self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
207
  self.forward_model.to("cuda")
208
-
209
- @spaces.GPU(duration=20)
210
- def predict_single_smiles(self, smiles, task_type):
211
- if task_type == "full_retro":
212
- if "." in smiles:
213
- return None
214
-
215
- task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"
216
- # canonicalize the smiles
217
- mol = Chem.MolFromSmiles(smiles)
218
- if mol is None:
219
- return None
220
- smiles = Chem.MolToSmiles(mol)
221
-
222
- smiles_list = [smiles]
223
- task_type_list = [task_type]
224
-
225
-
226
- df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
227
- test_dataset = Dataset.from_pandas(df)
228
- # construct the dataloader
229
- test_loader = torch.utils.data.DataLoader(
230
- test_dataset,
231
- batch_size=1,
232
- collate_fn=self.data_collator,
233
- )
234
 
 
 
235
  predictions = []
236
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
237
  with torch.no_grad():
@@ -276,6 +253,36 @@ class ReactionPredictionModel():
276
  predictions.append(canonized_smiles_list)
277
 
278
  rank, invalid_rate = compute_rank(predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  return rank
280
 
281
 
 
129
  token = os.environ.get("TOKEN")
130
  )
131
  self.load_forward_model(candidate_models[model])
132
+
133
+ print(self.forward_model.device, self.retro_model.device)
134
  string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
135
  string_template = json.load(open(string_template_path, 'r'))
136
  reactant_start_str = string_template['REACTANTS_START_STRING']
 
206
  )
207
  self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
208
  self.forward_model.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
+ @spaces.GPU(duration=20)
211
+ def predict(self, test_loader):
212
  predictions = []
213
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
214
  with torch.no_grad():
 
253
  predictions.append(canonized_smiles_list)
254
 
255
  rank, invalid_rate = compute_rank(predictions)
256
+
257
+ return rank
258
+
259
+ def predict_single_smiles(self, smiles, task_type):
260
+ if task_type == "full_retro":
261
+ if "." in smiles:
262
+ return None
263
+
264
+ task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"
265
+ # canonicalize the smiles
266
+ mol = Chem.MolFromSmiles(smiles)
267
+ if mol is None:
268
+ return None
269
+ smiles = Chem.MolToSmiles(mol)
270
+
271
+ smiles_list = [smiles]
272
+ task_type_list = [task_type]
273
+
274
+
275
+ df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
276
+ test_dataset = Dataset.from_pandas(df)
277
+ # construct the dataloader
278
+ test_loader = torch.utils.data.DataLoader(
279
+ test_dataset,
280
+ batch_size=1,
281
+ collate_fn=self.data_collator,
282
+ )
283
+
284
+ rank = self.predict(test_loader)
285
+
286
  return rank
287
 
288