Commit f18943d
jiangjiechen committed
1 Parent(s): c0884b8

fix device issue: cpu
src/check_client/fact_checker.py
CHANGED
@@ -56,9 +56,9 @@ class FactChecker:
         self.tokenizer = None
         self.model = None
         self.args = args
+        self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
         self.mask_rate = mask_rate
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

         logger.info('Initializing fact checker.')
         self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)

@@ -98,7 +98,7 @@ class FactChecker:
         dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

         with torch.no_grad():
-            self.model.to(self.device)
+            self.model.to(self.args.device)
             self.model.eval()
             iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
             _, y_predicted, z_predicted, m_attn, mask = \
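In short, the fix resolves the device once during initialization and stores it on args, so the checker and any helper code that only receives args move the model and tensors to the same place, falling back to CPU when CUDA is unavailable. A minimal sketch of the pattern; the class and attribute names below are illustrative stand-ins, not the Space's actual API:

import torch

class DeviceAwareChecker:  # illustrative stand-in for the Space's FactChecker
    def __init__(self, args):
        self.args = args
        # Resolve the device once and store it on args, so code that only
        # receives `args` (e.g. an evaluation loop) can reuse the same device.
        self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = torch.nn.Linear(4, 2)  # placeholder for the loaded checkpoint

    def check(self, features):
        with torch.no_grad():
            # Move the model to the resolved device before inference,
            # mirroring the self.model.to(self.args.device) change above.
            self.model.to(self.args.device)
            self.model.eval()
            return self.model(features.to(self.args.device))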
src/check_client/train.py
CHANGED
@@ -267,7 +267,7 @@ def do_evaluate(dataloader, model, args, during_training=False, with_label=True)
     mask = []
     for i, batch in enumerate(dataloader):
         model.eval()
-        batch = tuple(t.to(
+        batch = tuple(t.to(args.device) for t in batch)
         with torch.no_grad():
             inputs = {
                 "claim_input_ids": batch[0],
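The same stored device is then picked up in the evaluation loop: every tensor in a batch is moved to args.device before the forward pass. A self-contained sketch of that loop, with a toy model, dataset, and Args holder standing in for the Space's actual do_evaluate arguments:

import torch
from torch.utils.data import DataLoader, TensorDataset

def evaluate(dataloader, model, args):
    # Mirrors the fixed line: every tensor in the batch is moved to the
    # device resolved on `args` before it reaches the model.
    model.to(args.device)
    model.eval()
    outputs = []
    for batch in dataloader:
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            outputs.append(model(batch[0]))
    return torch.cat(outputs)

# Toy usage (names and shapes are illustrative only):
class Args:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(4, 2)
data = TensorDataset(torch.randn(8, 4))
preds = evaluate(DataLoader(data, batch_size=4), model, Args())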