fix: testing cpu inference
Browse files
script.py
CHANGED
@@ -26,11 +26,12 @@ class PytorchWorker:
|
|
26 |
|
27 |
model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
|
28 |
|
29 |
-
if not torch.cuda.is_available():
|
30 |
-
|
31 |
-
else:
|
32 |
-
|
33 |
|
|
|
34 |
model.load_state_dict(model_ckpt)
|
35 |
|
36 |
return model.to(device).eval()
|
|
|
26 |
|
27 |
model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
|
28 |
|
29 |
+
# if not torch.cuda.is_available():
|
30 |
+
# model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
|
31 |
+
# else:
|
32 |
+
# model_ckpt = torch.load(model_path)
|
33 |
|
34 |
+
model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
|
35 |
model.load_state_dict(model_ckpt)
|
36 |
|
37 |
return model.to(device).eval()
|