anjikum commited on
Commit
22a11e8
·
verified ·
1 Parent(s): 4ec6f83

added changes to model

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -28,6 +28,8 @@ device = torch.device('cpu')
28
 
29
  # Load your trained ResNet-50 model (or any custom architecture)
30
  model = models.resnet50(pretrained=False) # Load the ResNet-50 architecture
 
 
31
  model.load_state_dict(torch.load("model.pth", map_location=device)) # Load the trained weights (.pth)
32
  model.to(device) # Move model to CPU (even if you have a GPU)
33
 
@@ -42,8 +44,7 @@ transform = transforms.Compose([
42
  ])
43
 
44
  # Define the labels for ImageNet (or your specific dataset labels)
45
- LABELS = ["class_1", "class_2", "class_3", "class_4", "class_5", # Replace with your classes
46
- "class_6", "class_7", "class_8", "class_9", "class_10"]
47
 
48
  # Define the prediction function
49
  def predict(image):
 
28
 
29
  # Load your trained ResNet-50 model (or any custom architecture)
30
  model = models.resnet50(pretrained=False) # Load the ResNet-50 architecture
31
+ model.fc = nn.Linear(model.fc.in_features, 1000)
32
+
33
  model.load_state_dict(torch.load("model.pth", map_location=device)) # Load the trained weights (.pth)
34
  model.to(device) # Move model to CPU (even if you have a GPU)
35
 
 
44
  ])
45
 
46
  # Define the labels for ImageNet (or your specific dataset labels)
47
+ LABELS = [f"class_{k}" for k in range(1,1001)]
 
48
 
49
  # Define the prediction function
50
  def predict(image):