Spaces:
Build error
Build error
jiangjiechen
commited on
Commit
•
817b9ee
1
Parent(s):
0d63b55
fix device issue: cpu
Browse files
src/er_client/sentence_selection.py
CHANGED
@@ -23,7 +23,8 @@ class SentSelector:
|
|
23 |
self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path)
|
24 |
|
25 |
self.rank_model = inference_model(self.bert_model, self.args)
|
26 |
-
self.rank_model.load_state_dict(torch.load(select_model_path
|
|
|
27 |
|
28 |
if self.use_cuda:
|
29 |
self.bert_model = self.bert_model.cuda()
|
|
|
23 |
self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path)
|
24 |
|
25 |
self.rank_model = inference_model(self.bert_model, self.args)
|
26 |
+
self.rank_model.load_state_dict(torch.load(select_model_path,
|
27 |
+
map_location=None if self.use_cuda else torch.device('cpu'))['model'])
|
28 |
|
29 |
if self.use_cuda:
|
30 |
self.bert_model = self.bert_model.cuda()
|