nragrawal commited on
Commit
d957efc
·
1 Parent(s): ef4c1d8

Update app to load from checkpoint

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -9,9 +9,19 @@ from network import create_model # Import our model architecture
9
  # Load model from local checkpoint
10
  def load_model():
11
  try:
12
- model = create_model(num_classes=1000) # Match your training classes
13
  checkpoint = torch.load('model/model_best.pth', map_location='cpu')
14
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
 
 
 
15
  model.eval()
16
  return model
17
  except Exception as e:
 
9
  # Load model from local checkpoint
10
  def load_model():
11
  try:
12
+ model = create_model(num_classes=1000)
13
  checkpoint = torch.load('model/model_best.pth', map_location='cpu')
14
+
15
+ # Handle DataParallel state dict
16
+ state_dict = checkpoint['model_state_dict']
17
+ # Remove 'module.' prefix if it exists
18
+ new_state_dict = {}
19
+ for k, v in state_dict.items():
20
+ name = k.replace('module.', '') # Remove 'module.' prefix
21
+ new_state_dict[name] = v
22
+
23
+ # Load the modified state dict
24
+ model.load_state_dict(new_state_dict)
25
  model.eval()
26
  return model
27
  except Exception as e: