Rename Classify_product.py to app.py
Browse files
Classify_product.py → app.py
RENAMED
@@ -13,7 +13,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
13 |
main_model = models.resnet18(pretrained=False)
|
14 |
num_ftrs = main_model.fc.in_features
|
15 |
main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones
|
16 |
-
main_model.load_state_dict(torch.load('
|
17 |
main_model = main_model.to(device)
|
18 |
main_model.eval()
|
19 |
|
@@ -25,7 +25,7 @@ def load_soda_drinks_model():
|
|
25 |
model = models.resnet18(pretrained=False)
|
26 |
num_ftrs = model.fc.in_features
|
27 |
model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Miranda, Pepsi, Seven Up
|
28 |
-
model.load_state_dict(torch.load('
|
29 |
model = model.to(device)
|
30 |
model.eval()
|
31 |
return model
|
@@ -34,7 +34,7 @@ def load_clothing_model():
|
|
34 |
model = models.resnet18(pretrained=False)
|
35 |
num_ftrs = model.fc.in_features
|
36 |
model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Pants, T-Shirt
|
37 |
-
model.load_state_dict(torch.load('
|
38 |
model = model.to(device)
|
39 |
model.eval()
|
40 |
return model
|
@@ -43,7 +43,7 @@ def load_mobile_phones_model():
|
|
43 |
model = models.resnet18(pretrained=False)
|
44 |
num_ftrs = model.fc.in_features
|
45 |
model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Apple, Samsung
|
46 |
-
model.load_state_dict(torch.load('
|
47 |
model = model.to(device)
|
48 |
model.eval()
|
49 |
return model
|
|
|
13 |
main_model = models.resnet18(pretrained=False)
|
14 |
num_ftrs = main_model.fc.in_features
|
15 |
main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones
|
16 |
+
main_model.load_state_dict(torch.load('Main_Classifier_best_model.pth', map_location=device))
|
17 |
main_model = main_model.to(device)
|
18 |
main_model.eval()
|
19 |
|
|
|
25 |
model = models.resnet18(pretrained=False)
|
26 |
num_ftrs = model.fc.in_features
|
27 |
model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Miranda, Pepsi, Seven Up
|
28 |
+
model.load_state_dict(torch.load('Soda_drinks_best_model.pth', map_location=device))
|
29 |
model = model.to(device)
|
30 |
model.eval()
|
31 |
return model
|
|
|
34 |
model = models.resnet18(pretrained=False)
|
35 |
num_ftrs = model.fc.in_features
|
36 |
model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Pants, T-Shirt
|
37 |
+
model.load_state_dict(torch.load('Clothes_best_model.pth', map_location=device))
|
38 |
model = model.to(device)
|
39 |
model.eval()
|
40 |
return model
|
|
|
43 |
model = models.resnet18(pretrained=False)
|
44 |
num_ftrs = model.fc.in_features
|
45 |
model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Apple, Samsung
|
46 |
+
model.load_state_dict(torch.load('Phone_best_model.pth', map_location=device))
|
47 |
model = model.to(device)
|
48 |
model.eval()
|
49 |
return model
|