Update utils.py
Browse files
utils.py
CHANGED
@@ -76,7 +76,7 @@ def model_pred_tf(model_path, img, class_names=classes):
|
|
76 |
def get_model_pt(model_path):
|
77 |
model = timm.create_model('vit_base_patch16_224', pretrained=False)
|
78 |
model.head = nn.Linear(in_features=768, out_features=len(classes), bias=True)
|
79 |
-
model.load_state_dict(torch.load(
|
80 |
return model
|
81 |
|
82 |
def load_prepare_image_pt(input_image):
|
|
|
76 |
def get_model_pt(model_path):
|
77 |
model = timm.create_model('vit_base_patch16_224', pretrained=False)
|
78 |
model.head = nn.Linear(in_features=768, out_features=len(classes), bias=True)
|
79 |
+
model.load_state_dict(torch.load(model_path, map_location='cpu'))
|
80 |
return model
|
81 |
|
82 |
def load_prepare_image_pt(input_image):
|