added changes to model
Browse files
app.py
CHANGED
@@ -28,6 +28,8 @@ device = torch.device('cpu')
|
|
28 |
|
29 |
# Load your trained ResNet-50 model (or any custom architecture)
|
30 |
model = models.resnet50(pretrained=False) # Load the ResNet-50 architecture
|
|
|
|
|
31 |
model.load_state_dict(torch.load("model.pth", map_location=device)) # Load the trained weights (.pth)
|
32 |
model.to(device) # Move model to CPU (even if you have a GPU)
|
33 |
|
@@ -42,8 +44,7 @@ transform = transforms.Compose([
|
|
42 |
])
|
43 |
|
44 |
# Define the labels for ImageNet (or your specific dataset labels)
|
45 |
-
LABELS = ["
|
46 |
-
"class_6", "class_7", "class_8", "class_9", "class_10"]
|
47 |
|
48 |
# Define the prediction function
|
49 |
def predict(image):
|
|
|
28 |
|
29 |
# Load your trained ResNet-50 model (or any custom architecture)
|
30 |
model = models.resnet50(pretrained=False) # Load the ResNet-50 architecture
|
31 |
+
model.fc = nn.Linear(model.fc.in_features, 1000)
|
32 |
+
|
33 |
model.load_state_dict(torch.load("model.pth", map_location=device)) # Load the trained weights (.pth)
|
34 |
model.to(device) # Move model to CPU (even if you have a GPU)
|
35 |
|
|
|
44 |
])
|
45 |
|
46 |
# Define the labels for ImageNet (or your specific dataset labels)
|
47 |
+
LABELS = [f"class_{k}" for k in range(1,1001)]
|
|
|
48 |
|
49 |
# Define the prediction function
|
50 |
def predict(image):
|