feiyang-cai commited on
Commit
eefdfab
·
1 Parent(s): 4294df8
Files changed (2) hide show
  1. app.py +1 -1
  2. utils.py +0 -5
app.py CHANGED
@@ -53,7 +53,7 @@ def get_description(task_name):
53
  task = task_names_to_tasks[task_name]
54
  return task_descriptions[task]
55
 
56
- #@spaces.GPU(duration=10)
57
  def predict_single_label(smiles, task_name):
58
  task = task_names_to_tasks[task_name]
59
 
 
53
  task = task_names_to_tasks[task_name]
54
  return task_descriptions[task]
55
 
56
+ @spaces.GPU(duration=60)
57
  def predict_single_label(smiles, task_name):
58
  task = task_names_to_tasks[task_name]
59
 
utils.py CHANGED
@@ -209,7 +209,6 @@ class ReactionPredictionModel():
209
  self.forward_model.to("cuda")
210
  self.forward_model.eval()
211
 
212
- @spaces.GPU(duration=60)
213
  def predict(self, test_loader, task_type):
214
  predictions = []
215
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
@@ -219,8 +218,6 @@ class ReactionPredictionModel():
219
  del inputs['token_type_ids']
220
 
221
  if task_type == "retrosynthesis":
222
- self.retro_model.to("cuda")
223
- self.retro_model.eval()
224
  inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
225
  print(inputs)
226
  print(self.retro_model.device)
@@ -233,8 +230,6 @@ class ReactionPredictionModel():
233
  length_penalty=0.0,
234
  )
235
  else:
236
- self.forward_model.to("cuda")
237
- self.forward_model.eval()
238
  inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
239
  print(inputs)
240
  print(self.forward_model.device)
 
209
  self.forward_model.to("cuda")
210
  self.forward_model.eval()
211
 
 
212
  def predict(self, test_loader, task_type):
213
  predictions = []
214
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
 
218
  del inputs['token_type_ids']
219
 
220
  if task_type == "retrosynthesis":
 
 
221
  inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
222
  print(inputs)
223
  print(self.retro_model.device)
 
230
  length_penalty=0.0,
231
  )
232
  else:
 
 
233
  inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
234
  print(inputs)
235
  print(self.forward_model.device)