jiangjiechen commited on
Commit
817b9ee
1 Parent(s): 0d63b55

fix device issue: cpu

Browse files
src/er_client/sentence_selection.py CHANGED
@@ -23,7 +23,8 @@ class SentSelector:
23
  self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path)
24
 
25
  self.rank_model = inference_model(self.bert_model, self.args)
26
- self.rank_model.load_state_dict(torch.load(select_model_path)['model'])
 
27
 
28
  if self.use_cuda:
29
  self.bert_model = self.bert_model.cuda()
 
23
  self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path)
24
 
25
  self.rank_model = inference_model(self.bert_model, self.args)
26
+ self.rank_model.load_state_dict(torch.load(select_model_path,
27
+ map_location=None if self.use_cuda else torch.device('cpu'))['model'])
28
 
29
  if self.use_cuda:
30
  self.bert_model = self.bert_model.cuda()