Spaces:
Sleeping
Sleeping
Commit
·
e1acb50
1
Parent(s):
19d436b
Update tamilatis/predict.py
Browse files- tamilatis/predict.py +2 -2
tamilatis/predict.py
CHANGED
@@ -16,10 +16,10 @@ class TamilATISPredictor:
|
|
16 |
num_labels,
|
17 |
):
|
18 |
self.model = model
|
19 |
-
self.model.load_state_dict(torch.load(checkpoint_path))
|
20 |
self.model.eval()
|
21 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
22 |
-
self.device = "cuda" if torch.cuda.is_available() else "
|
|
|
23 |
self.num_labels = num_labels
|
24 |
self.label_encoder = label_encoder
|
25 |
self.intent_encoder = intent_encoder
|
|
|
16 |
num_labels,
|
17 |
):
|
18 |
self.model = model
|
|
|
19 |
self.model.eval()
|
20 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
21 |
+
self.device = "cuda" if torch.cuda.is_available() else "CPU"
|
22 |
+
self.model.load_state_dict(torch.load(checkpoint_path,map_location=self.device))
|
23 |
self.num_labels = num_labels
|
24 |
self.label_encoder = label_encoder
|
25 |
self.intent_encoder = intent_encoder
|