feiyang-cai committed on
Commit
954342d
·
1 Parent(s): edaff0a
Files changed (1) hide show
  1. utils.py +5 -1
utils.py CHANGED
@@ -236,7 +236,11 @@ class ReactionPredictionModel():
236
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
237
  with torch.no_grad():
238
  generation_prompts = batch['generation_prompts'][0]
239
- inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True).to(self.retro_model.device)
 
 
 
 
240
  del inputs['token_type_ids']
241
  if task_type == "retrosynthesis":
242
  outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
 
236
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
237
  with torch.no_grad():
238
  generation_prompts = batch['generation_prompts'][0]
239
+ inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
240
+ inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
241
+ print(inputs)
242
+ print(self.forward_model.device)
243
+ print(self.retro_model.device)
244
  del inputs['token_type_ids']
245
  if task_type == "retrosynthesis":
246
  outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,