Update utils.py
utils.py
CHANGED
@@ -180,6 +180,7 @@ class ReactionPredictionModel():
         )
 
         self.retro_model.to("cuda")
+        self.retro_model.eval()
 
     def load_forward_model(self, model_path):
         config = AutoConfig.from_pretrained(
@@ -206,21 +207,24 @@ class ReactionPredictionModel():
         )
         self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
         self.forward_model.to("cuda")
+        self.forward_model.eval()
 
     @spaces.GPU(duration=30)
     def predict(self, test_loader, task_type):
         predictions = []
         for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
-
-
-
-
-
-
-
-
-
-
+
+            generation_prompts = batch['generation_prompts'][0]
+            inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
+            del inputs['token_type_ids']
+
+            if task_type == "retrosynthesis":
+                self.retro_model.to("cuda")
+                self.retro_model.eval()
+                inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
+                print(inputs)
+                print(self.retro_model.device)
+                with torch.no_grad():
                     outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                         do_sample=False, num_beams=10,
                                                         eos_token_id=self.tokenizer.eos_token_id,
@@ -228,11 +232,13 @@ class ReactionPredictionModel():
                                                         pad_token_id=self.tokenizer.pad_token_id,
                                                         length_penalty=0.0,
                                                         )
-
-
-
-
-
+            else:
+                self.forward_model.to("cuda")
+                self.forward_model.eval()
+                inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
+                print(inputs)
+                print(self.forward_model.device)
+                with torch.no_grad():
                     outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                           do_sample=False, num_beams=10,
                                                           eos_token_id=self.tokenizer.eos_token_id,
@@ -241,22 +247,22 @@ class ReactionPredictionModel():
                                                           length_penalty=0.0,
                                                           )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            print(outputs)
+            original_smiles_list = self.tokenizer.batch_decode(outputs[:, len(inputs['input_ids'][0]):],
+                                                               skip_special_tokens=True)
+            original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
+            # canonize the SMILES
+            canonized_smiles_list = []
+            temp = []
+            for original_smiles in original_smiles_list:
+                temp.append(original_smiles)
+                try:
+                    canonized_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(original_smiles)))
+                except:
+                    canonized_smiles_list.append("")
+            #canonized_smiles_list = \
+            #['N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1cc(F)c([N+](=O)[O-])cc1F', 'N#Cc1ccsc1Nc1cc(Cl)c(F)cc1[N+](=O)[O-]', 'N#Cc1cnsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1cc(F)c(F)cc1Nc1sccc1C#N', 'N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=N)[O-]', 'N#Cc1cc(C#N)c(Nc2cc(F)c(F)cc2[N+](=O)[O-])s1', 'N#Cc1ccsc1Nc1c(F)c(F)cc(F)c1[N+](=O)[O-]', 'Nc1sccc1CNc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1ccc(F)cc1[N+](=O)[O-]']
+            predictions.append(canonized_smiles_list)
 
         rank, invalid_rate = compute_rank(predictions)
         print(predictions, rank)
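The substance of this commit is running generation in inference mode: both models get an eval() call after being moved to CUDA, and each generate() call is wrapped in torch.no_grad(). A minimal sketch of that pattern (the function and argument names here are illustrative, not from utils.py):

    import torch

    def generate_in_inference_mode(model, inputs, **gen_kwargs):
        model.eval()            # disables dropout; does not stop autograd by itself
        with torch.no_grad():   # stops autograd recording, cutting activation memory
            return model.generate(**inputs, **gen_kwargs)

Recent transformers releases already run generate() under no_grad internally, so the explicit block mainly guards older versions; the eval() call still matters whenever the model uses dropout.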
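Inside the loop, the batch's single prompt string is tokenized and token_type_ids is removed before the tensors are moved to the model's device; causal-LM generate() rejects that key, but `del` raises KeyError for tokenizers that never emit it. A slightly more defensive variant (a sketch, not part of the commit):

    def build_generate_inputs(tokenizer, prompts, device):
        # BatchEncoding is dict-like, so pop() works and tolerates a missing key.
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
        inputs.pop("token_type_ids", None)  # only some tokenizers produce this key
        return {k: v.to(device) for k, v in inputs.items()}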
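The decode step slices off the prompt with outputs[:, len(inputs['input_ids'][0]):]. That relies on every returned sequence starting with the same prompt tokens, which holds here: each batch carries one prompt and beam search returns ten continuations of it. The same step as a standalone helper (a sketch under those assumptions):

    def decode_continuations(tokenizer, inputs, outputs):
        prompt_len = inputs["input_ids"].shape[1]   # prompt tokens shared by all beams
        texts = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
        return [t.replace(" ", "") for t in texts]  # the tokenizer may insert spaces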
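The bare try/except in the canonicalization loop works because Chem.MolFromSmiles returns None for an unparsable SMILES and Chem.MolToSmiles(None) then raises; the handler turns that into an empty string, which presumably counts toward invalid_rate in compute_rank. The same behavior with an explicit None check instead of a bare except (a sketch):

    from rdkit import Chem

    def canonize(smiles: str) -> str:
        # Canonical SMILES via RDKit, or "" when the string does not parse.
        mol = Chem.MolFromSmiles(smiles)
        return Chem.MolToSmiles(mol) if mol is not None else ""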