Ahmed-El-Sharkawy commited on
Commit
90c6904
·
verified ·
1 Parent(s): 7b09497

Rename Classify_product.py to app.py

Browse files
Files changed (1) hide show
  1. Classify_product.py → app.py +4 -4
Classify_product.py → app.py RENAMED
@@ -13,7 +13,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
  main_model = models.resnet18(pretrained=False)
14
  num_ftrs = main_model.fc.in_features
15
  main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones
16
- main_model.load_state_dict(torch.load('Saved Model\Main_Classifier_best_model.pth', map_location=device))
17
  main_model = main_model.to(device)
18
  main_model.eval()
19
 
@@ -25,7 +25,7 @@ def load_soda_drinks_model():
25
  model = models.resnet18(pretrained=False)
26
  num_ftrs = model.fc.in_features
27
  model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Miranda, Pepsi, Seven Up
28
- model.load_state_dict(torch.load('Saved Model\Soda_drinks_best_model.pth', map_location=device))
29
  model = model.to(device)
30
  model.eval()
31
  return model
@@ -34,7 +34,7 @@ def load_clothing_model():
34
  model = models.resnet18(pretrained=False)
35
  num_ftrs = model.fc.in_features
36
  model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Pants, T-Shirt
37
- model.load_state_dict(torch.load('Saved Model\Clothes_best_model.pth', map_location=device))
38
  model = model.to(device)
39
  model.eval()
40
  return model
@@ -43,7 +43,7 @@ def load_mobile_phones_model():
43
  model = models.resnet18(pretrained=False)
44
  num_ftrs = model.fc.in_features
45
  model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Apple, Samsung
46
- model.load_state_dict(torch.load('Saved Model\Phone_best_model.pth', map_location=device))
47
  model = model.to(device)
48
  model.eval()
49
  return model
 
13
  main_model = models.resnet18(pretrained=False)
14
  num_ftrs = main_model.fc.in_features
15
  main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones
16
+ main_model.load_state_dict(torch.load('Main_Classifier_best_model.pth', map_location=device))
17
  main_model = main_model.to(device)
18
  main_model.eval()
19
 
 
25
  model = models.resnet18(pretrained=False)
26
  num_ftrs = model.fc.in_features
27
  model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Miranda, Pepsi, Seven Up
28
+ model.load_state_dict(torch.load('Soda_drinks_best_model.pth', map_location=device))
29
  model = model.to(device)
30
  model.eval()
31
  return model
 
34
  model = models.resnet18(pretrained=False)
35
  num_ftrs = model.fc.in_features
36
  model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Pants, T-Shirt
37
+ model.load_state_dict(torch.load('Clothes_best_model.pth', map_location=device))
38
  model = model.to(device)
39
  model.eval()
40
  return model
 
43
  model = models.resnet18(pretrained=False)
44
  num_ftrs = model.fc.in_features
45
  model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Apple, Samsung
46
+ model.load_state_dict(torch.load('Phone_best_model.pth', map_location=device))
47
  model = model.to(device)
48
  model.eval()
49
  return model