jiangjiechen commited on
Commit
c0884b8
·
1 Parent(s): 460c37e

fix device issue: cpu

Browse files
src/check_client/fact_checker.py CHANGED
@@ -58,6 +58,7 @@ class FactChecker:
58
  self.args = args
59
  self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
60
  self.mask_rate = mask_rate
 
61
 
62
  logger.info('Initializing fact checker.')
63
  self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
@@ -97,7 +98,7 @@ class FactChecker:
97
  dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
98
 
99
  with torch.no_grad():
100
- self.model.to(self.args.device)
101
  self.model.eval()
102
  iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
103
  _, y_predicted, z_predicted, m_attn, mask = \
@@ -192,8 +193,6 @@ if __name__ == '__main__':
192
 
193
  set_seed(args)
194
 
195
- args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
196
-
197
  if args.output == 'none':
198
  args.ckpt = args.ckpt[:-1] if args.ckpt.endswith('/') else args.ckpt
199
  base_name = os.path.basename(args.ckpt)
 
58
  self.args = args
59
  self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
60
  self.mask_rate = mask_rate
61
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
 
63
  logger.info('Initializing fact checker.')
64
  self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
 
98
  dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
99
 
100
  with torch.no_grad():
101
+ self.model.to(self.device)
102
  self.model.eval()
103
  iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
104
  _, y_predicted, z_predicted, m_attn, mask = \
 
193
 
194
  set_seed(args)
195
 
 
 
196
  if args.output == 'none':
197
  args.ckpt = args.ckpt[:-1] if args.ckpt.endswith('/') else args.ckpt
198
  base_name = os.path.basename(args.ckpt)
src/check_client/train.py CHANGED
@@ -267,7 +267,7 @@ def do_evaluate(dataloader, model, args, during_training=False, with_label=True)
267
  mask = []
268
  for i, batch in enumerate(dataloader):
269
  model.eval()
270
- batch = tuple(t.to(args.device) for t in batch)
271
  with torch.no_grad():
272
  inputs = {
273
  "claim_input_ids": batch[0],
 
267
  mask = []
268
  for i, batch in enumerate(dataloader):
269
  model.eval()
270
+ batch = tuple(t.to(model.device) for t in batch)
271
  with torch.no_grad():
272
  inputs = {
273
  "claim_input_ids": batch[0],