Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -82,12 +82,22 @@ class ModelManager:
|
|
82 |
num_classes=len(dog_breeds),
|
83 |
device=self.device
|
84 |
).to(self.device)
|
85 |
-
|
86 |
checkpoint = torch.load(
|
87 |
'ConvNextV2Base_best_model.pth',
|
88 |
map_location=self.device
|
89 |
)
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
self._breed_model.eval()
|
92 |
return self._breed_model
|
93 |
|
|
|
82 |
num_classes=len(dog_breeds),
|
83 |
device=self.device
|
84 |
).to(self.device)
|
85 |
+
|
86 |
checkpoint = torch.load(
|
87 |
'ConvNextV2Base_best_model.pth',
|
88 |
map_location=self.device
|
89 |
)
|
90 |
+
|
91 |
+
# Try to load with model_state_dict first, then base_model
|
92 |
+
if 'model_state_dict' in checkpoint:
|
93 |
+
self._breed_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
94 |
+
elif 'base_model' in checkpoint:
|
95 |
+
self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)
|
96 |
+
else:
|
97 |
+
# If neither key exists, raise a descriptive error
|
98 |
+
available_keys = list(checkpoint.keys()) if isinstance(checkpoint, dict) else "not a dictionary"
|
99 |
+
raise KeyError(f"Model checkpoint does not contain 'model_state_dict' or 'base_model' keys. Available keys: {available_keys}")
|
100 |
+
|
101 |
self._breed_model.eval()
|
102 |
return self._breed_model
|
103 |
|