jiangjiechen commited on
Commit
f18943d
1 Parent(s): c0884b8

fix device issue: cpu

Browse files
src/check_client/fact_checker.py CHANGED
@@ -56,9 +56,9 @@ class FactChecker:
56
  self.tokenizer = None
57
  self.model = None
58
  self.args = args
 
59
  self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
60
  self.mask_rate = mask_rate
61
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
 
63
  logger.info('Initializing fact checker.')
64
  self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
@@ -98,7 +98,7 @@ class FactChecker:
98
  dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
99
 
100
  with torch.no_grad():
101
- self.model.to(self.device)
102
  self.model.eval()
103
  iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
104
  _, y_predicted, z_predicted, m_attn, mask = \
 
56
  self.tokenizer = None
57
  self.model = None
58
  self.args = args
59
+ self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
  self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
61
  self.mask_rate = mask_rate
 
62
 
63
  logger.info('Initializing fact checker.')
64
  self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
 
98
  dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
99
 
100
  with torch.no_grad():
101
+ self.model.to(self.args.device)
102
  self.model.eval()
103
  iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
104
  _, y_predicted, z_predicted, m_attn, mask = \
src/check_client/train.py CHANGED
@@ -267,7 +267,7 @@ def do_evaluate(dataloader, model, args, during_training=False, with_label=True)
267
  mask = []
268
  for i, batch in enumerate(dataloader):
269
  model.eval()
270
- batch = tuple(t.to(model.device) for t in batch)
271
  with torch.no_grad():
272
  inputs = {
273
  "claim_input_ids": batch[0],
 
267
  mask = []
268
  for i, batch in enumerate(dataloader):
269
  model.eval()
270
+ batch = tuple(t.to(args.device) for t in batch)
271
  with torch.no_grad():
272
  inputs = {
273
  "claim_input_ids": batch[0],