feiyang-cai commited on
Commit
8a703f0
·
verified ·
1 Parent(s): 423e3eb

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +4 -1
utils.py CHANGED
@@ -191,7 +191,7 @@ class MolecularPropertyPredictionModel():
191
  self.base_model = AutoModelForSequenceClassification.from_pretrained(
192
  "ChemFM/ChemFM-3B",
193
  config=config,
194
- device_map="cpu",
195
  trust_remote_code=True,
196
  token = os.environ.get("TOKEN")
197
  )
@@ -284,7 +284,10 @@ class MolecularPropertyPredictionModel():
284
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
285
  with torch.no_grad():
286
  batch = {k: v.to(self.base_model.device) for k, v in batch.items()}
 
 
287
  outputs = self.base_model(**batch)
 
288
  if task_type == "regression": # TODO: check if the model is regression or classification
289
  y_pred.append(outputs.logits.cpu().detach().numpy())
290
  else:
 
191
  self.base_model = AutoModelForSequenceClassification.from_pretrained(
192
  "ChemFM/ChemFM-3B",
193
  config=config,
194
+ device_map="cuda",
195
  trust_remote_code=True,
196
  token = os.environ.get("TOKEN")
197
  )
 
284
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
285
  with torch.no_grad():
286
  batch = {k: v.to(self.base_model.device) for k, v in batch.items()}
287
+ print(self.base_model.device)
288
+ print(batch)
289
  outputs = self.base_model(**batch)
290
+ print(output)
291
  if task_type == "regression": # TODO: check if the model is regression or classification
292
  y_pred.append(outputs.logits.cpu().detach().numpy())
293
  else: