Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
eefdfab
1
Parent(s):
4294df8
update
Browse files
app.py
CHANGED
@@ -53,7 +53,7 @@ def get_description(task_name):
|
|
53 |
task = task_names_to_tasks[task_name]
|
54 |
return task_descriptions[task]
|
55 |
|
56 |
-
|
57 |
def predict_single_label(smiles, task_name):
|
58 |
task = task_names_to_tasks[task_name]
|
59 |
|
|
|
53 |
task = task_names_to_tasks[task_name]
|
54 |
return task_descriptions[task]
|
55 |
|
56 |
+
@spaces.GPU(duration=60)
|
57 |
def predict_single_label(smiles, task_name):
|
58 |
task = task_names_to_tasks[task_name]
|
59 |
|
utils.py
CHANGED
@@ -209,7 +209,6 @@ class ReactionPredictionModel():
|
|
209 |
self.forward_model.to("cuda")
|
210 |
self.forward_model.eval()
|
211 |
|
212 |
-
@spaces.GPU(duration=60)
|
213 |
def predict(self, test_loader, task_type):
|
214 |
predictions = []
|
215 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
@@ -219,8 +218,6 @@ class ReactionPredictionModel():
|
|
219 |
del inputs['token_type_ids']
|
220 |
|
221 |
if task_type == "retrosynthesis":
|
222 |
-
self.retro_model.to("cuda")
|
223 |
-
self.retro_model.eval()
|
224 |
inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
|
225 |
print(inputs)
|
226 |
print(self.retro_model.device)
|
@@ -233,8 +230,6 @@ class ReactionPredictionModel():
|
|
233 |
length_penalty=0.0,
|
234 |
)
|
235 |
else:
|
236 |
-
self.forward_model.to("cuda")
|
237 |
-
self.forward_model.eval()
|
238 |
inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
|
239 |
print(inputs)
|
240 |
print(self.forward_model.device)
|
|
|
209 |
self.forward_model.to("cuda")
|
210 |
self.forward_model.eval()
|
211 |
|
|
|
212 |
def predict(self, test_loader, task_type):
|
213 |
predictions = []
|
214 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
|
|
218 |
del inputs['token_type_ids']
|
219 |
|
220 |
if task_type == "retrosynthesis":
|
|
|
|
|
221 |
inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
|
222 |
print(inputs)
|
223 |
print(self.retro_model.device)
|
|
|
230 |
length_penalty=0.0,
|
231 |
)
|
232 |
else:
|
|
|
|
|
233 |
inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
|
234 |
print(inputs)
|
235 |
print(self.forward_model.device)
|