anjikum commited on
Commit
1b084ff
·
verified ·
1 Parent(s): 22a11e8

added changes to app to fix model loading errors

Browse files
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -33,7 +33,22 @@ model.fc = nn.Linear(model.fc.in_features, 1000)
33
  model.load_state_dict(torch.load("model.pth", map_location=device)) # Load the trained weights (.pth)
34
  model.to(device) # Move model to CPU (even if you have a GPU)
35
 
36
- model.eval() # Set model to evaluation mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Define the transformation required for the input image
39
  transform = transforms.Compose([
 
33
  model.load_state_dict(torch.load("model.pth", map_location=device)) # Load the trained weights (.pth)
34
  model.to(device) # Move model to CPU (even if you have a GPU)
35
 
36
+ checkpoint = torch.load('model.pth', map_location='cpu')
37
+
38
+ # Load the model weights
39
+ model.load_state_dict(checkpoint['model_state_dict'])
40
+
41
+ # If you need to resume training, load optimizer and scheduler states
42
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
43
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
44
+
45
+ # If you want to resume from a specific epoch
46
+ epoch = checkpoint['epoch']
47
+
48
+ # Set the model to evaluation mode (for inference)
49
+ model.eval()
50
+
51
+ # model.eval() # Set model to evaluation mode
52
 
53
  # Define the transformation required for the input image
54
  transform = transforms.Compose([