DawnC commited on
Commit
ffa42ba
·
verified ·
1 Parent(s): 403f8f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -82,12 +82,22 @@ class ModelManager:
82
  num_classes=len(dog_breeds),
83
  device=self.device
84
  ).to(self.device)
85
-
86
  checkpoint = torch.load(
87
  'ConvNextV2Base_best_model.pth',
88
  map_location=self.device
89
  )
90
- self._breed_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
 
 
 
 
 
 
 
 
 
 
91
  self._breed_model.eval()
92
  return self._breed_model
93
 
 
82
  num_classes=len(dog_breeds),
83
  device=self.device
84
  ).to(self.device)
85
+
86
  checkpoint = torch.load(
87
  'ConvNextV2Base_best_model.pth',
88
  map_location=self.device
89
  )
90
+
91
+ # Try to load with model_state_dict first, then base_model
92
+ if 'model_state_dict' in checkpoint:
93
+ self._breed_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
94
+ elif 'base_model' in checkpoint:
95
+ self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)
96
+ else:
97
+ # If neither key exists, raise a descriptive error
98
+ available_keys = list(checkpoint.keys()) if isinstance(checkpoint, dict) else "not a dictionary"
99
+ raise KeyError(f"Model checkpoint does not contain 'model_state_dict' or 'base_model' keys. Available keys: {available_keys}")
100
+
101
  self._breed_model.eval()
102
  return self._breed_model
103