Spaces:
Build error
Build error
jiangjiechen
commited on
Commit
•
c0884b8
1
Parent(s):
460c37e
fix device issue: cpu
Browse files
src/check_client/fact_checker.py
CHANGED
@@ -58,6 +58,7 @@ class FactChecker:
|
|
58 |
self.args = args
|
59 |
self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
|
60 |
self.mask_rate = mask_rate
|
|
|
61 |
|
62 |
logger.info('Initializing fact checker.')
|
63 |
self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
|
@@ -97,7 +98,7 @@ class FactChecker:
|
|
97 |
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
|
98 |
|
99 |
with torch.no_grad():
|
100 |
-
self.model.to(self.
|
101 |
self.model.eval()
|
102 |
iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
|
103 |
_, y_predicted, z_predicted, m_attn, mask = \
|
@@ -192,8 +193,6 @@ if __name__ == '__main__':
|
|
192 |
|
193 |
set_seed(args)
|
194 |
|
195 |
-
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
196 |
-
|
197 |
if args.output == 'none':
|
198 |
args.ckpt = args.ckpt[:-1] if args.ckpt.endswith('/') else args.ckpt
|
199 |
base_name = os.path.basename(args.ckpt)
|
|
|
58 |
self.args = args
|
59 |
self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
|
60 |
self.mask_rate = mask_rate
|
61 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
62 |
|
63 |
logger.info('Initializing fact checker.')
|
64 |
self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
|
|
|
98 |
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
|
99 |
|
100 |
with torch.no_grad():
|
101 |
+
self.model.to(self.device)
|
102 |
self.model.eval()
|
103 |
iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
|
104 |
_, y_predicted, z_predicted, m_attn, mask = \
|
|
|
193 |
|
194 |
set_seed(args)
|
195 |
|
|
|
|
|
196 |
if args.output == 'none':
|
197 |
args.ckpt = args.ckpt[:-1] if args.ckpt.endswith('/') else args.ckpt
|
198 |
base_name = os.path.basename(args.ckpt)
|
src/check_client/train.py
CHANGED
@@ -267,7 +267,7 @@ def do_evaluate(dataloader, model, args, during_training=False, with_label=True)
|
|
267 |
mask = []
|
268 |
for i, batch in enumerate(dataloader):
|
269 |
model.eval()
|
270 |
-
batch = tuple(t.to(
|
271 |
with torch.no_grad():
|
272 |
inputs = {
|
273 |
"claim_input_ids": batch[0],
|
|
|
267 |
mask = []
|
268 |
for i, batch in enumerate(dataloader):
|
269 |
model.eval()
|
270 |
+
batch = tuple(t.to(model.device) for t in batch)
|
271 |
with torch.no_grad():
|
272 |
inputs = {
|
273 |
"claim_input_ids": batch[0],
|