Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
954342d
1
Parent(s):
edaff0a
debug
Browse files
utils.py
CHANGED
@@ -236,7 +236,11 @@ class ReactionPredictionModel():
|
|
236 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
237 |
with torch.no_grad():
|
238 |
generation_prompts = batch['generation_prompts'][0]
|
239 |
-
inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
|
|
|
|
|
|
|
|
|
240 |
del inputs['token_type_ids']
|
241 |
if task_type == "retrosynthesis":
|
242 |
outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
|
|
|
236 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
237 |
with torch.no_grad():
|
238 |
generation_prompts = batch['generation_prompts'][0]
|
239 |
+
inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
|
240 |
+
inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
|
241 |
+
print(inputs)
|
242 |
+
print(self.forward_model.device)
|
243 |
+
print(self.retro_model.device)
|
244 |
del inputs['token_type_ids']
|
245 |
if task_type == "retrosynthesis":
|
246 |
outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
|