seanbenhur commited on
Commit
e1acb50
·
1 Parent(s): 19d436b

Update tamilatis/predict.py

Browse files
Files changed (1) hide show
  1. tamilatis/predict.py +2 -2
tamilatis/predict.py CHANGED
@@ -16,10 +16,10 @@ class TamilATISPredictor:
16
  num_labels,
17
  ):
18
  self.model = model
19
- self.model.load_state_dict(torch.load(checkpoint_path))
20
  self.model.eval()
21
  self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
22
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
23
  self.num_labels = num_labels
24
  self.label_encoder = label_encoder
25
  self.intent_encoder = intent_encoder
 
16
  num_labels,
17
  ):
18
  self.model = model
 
19
  self.model.eval()
20
  self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
21
+ self.device = "cuda" if torch.cuda.is_available() else "CPU"
22
+ self.model.load_state_dict(torch.load(checkpoint_path,map_location=self.device))
23
  self.num_labels = num_labels
24
  self.label_encoder = label_encoder
25
  self.intent_encoder = intent_encoder