Ahmed-El-Sharkawy commited on
Commit
7b09497
·
verified ·
1 Parent(s): 606f7d3

Upload 7 files

Browse files
Classify_product.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ import torchvision.models as models
7
+ import numpy as np
8
+
9
+ # Set device
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+ # Load the main classifier (Main_Classifier_best_model.pth)
13
+ main_model = models.resnet18(pretrained=False)
14
+ num_ftrs = main_model.fc.in_features
15
+ main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones
16
+ main_model.load_state_dict(torch.load('Saved Model\Main_Classifier_best_model.pth', map_location=device))
17
+ main_model = main_model.to(device)
18
+ main_model.eval()
19
+
20
+ # Define class names for the main classifier based on folder structure
21
+ main_class_names = ['Clothing', 'Mobile Phones', 'Soda drinks']
22
+
23
+ # Sub-classifier models
24
+ def load_soda_drinks_model():
25
+ model = models.resnet18(pretrained=False)
26
+ num_ftrs = model.fc.in_features
27
+ model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Miranda, Pepsi, Seven Up
28
+ model.load_state_dict(torch.load('Saved Model\Soda_drinks_best_model.pth', map_location=device))
29
+ model = model.to(device)
30
+ model.eval()
31
+ return model
32
+
33
+ def load_clothing_model():
34
+ model = models.resnet18(pretrained=False)
35
+ num_ftrs = model.fc.in_features
36
+ model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Pants, T-Shirt
37
+ model.load_state_dict(torch.load('Saved Model\Clothes_best_model.pth', map_location=device))
38
+ model = model.to(device)
39
+ model.eval()
40
+ return model
41
+
42
+ def load_mobile_phones_model():
43
+ model = models.resnet18(pretrained=False)
44
+ num_ftrs = model.fc.in_features
45
+ model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Apple, Samsung
46
+ model.load_state_dict(torch.load('Saved Model\Phone_best_model.pth', map_location=device))
47
+ model = model.to(device)
48
+ model.eval()
49
+ return model
50
+
51
+ def convert_to_rgb(image):
52
+ """
53
+ Converts 'P' mode images with transparency to 'RGBA', and then to 'RGB'.
54
+ This is to avoid transparency issues during model training.
55
+ """
56
+ if image.mode in ('P', 'RGBA'):
57
+ return image.convert('RGB')
58
+ return image
59
+
60
+ # Define preprocessing transformations (same used during training)
61
+ preprocess = transforms.Compose([
62
+ transforms.Lambda(convert_to_rgb),
63
+ transforms.Resize((224, 224)),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization
66
+ ])
67
+
68
+ # Streamlit App Interface
69
+ st.title("Main Classifier and Sub-Classifier System")
70
+ st.write("Upload an image to classify whether it belongs to Clothing, Mobile Phones, or Soda Drinks. Based on the prediction, it will further classify within the subcategory.")
71
+
72
+ # Image uploader in Streamlit
73
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
74
+
75
+ if uploaded_file is not None:
76
+ # Open the image using PIL
77
+ image = Image.open(uploaded_file)
78
+
79
+ # Display the uploaded image
80
+ st.image(image, caption='Uploaded Image', use_column_width=True)
81
+ st.write("")
82
+ st.write("Classifying...")
83
+
84
+ # Preprocess the image
85
+ input_image = preprocess(image).unsqueeze(0).to(device)
86
+
87
+ # Perform inference with the main classifier
88
+ with torch.no_grad():
89
+ output = main_model(input_image)
90
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
91
+ confidence, predicted_class = torch.max(probabilities, 0)
92
+
93
+ # Display the main classifier result
94
+ main_prediction = main_class_names[predicted_class]
95
+ st.write(f"**Main Predicted Class:** {main_prediction}")
96
+ st.write(f"**Confidence:** {confidence.item():.4f}")
97
+
98
+ # Load and apply the sub-classifier based on the main classification
99
+ if main_prediction == 'Soda drinks':
100
+ st.write("Loading Soda Drinks Model...")
101
+ soda_model = load_soda_drinks_model()
102
+ sub_class_names = ['Miranda', 'Pepsi', 'Seven Up']
103
+ elif main_prediction == 'Clothing':
104
+ st.write("Loading Clothing Model...")
105
+ clothing_model = load_clothing_model()
106
+ sub_class_names = ['Pants', 'T-Shirt']
107
+ elif main_prediction == 'Mobile Phones':
108
+ st.write("Loading Mobile Phones Model...")
109
+ phones_model = load_mobile_phones_model()
110
+ sub_class_names = ['Apple', 'Samsung']
111
+
112
+ # Perform inference with the sub-classifier
113
+ with torch.no_grad():
114
+ if main_prediction == 'Soda drinks':
115
+ sub_output = soda_model(input_image)
116
+ elif main_prediction == 'Clothing':
117
+ sub_output = clothing_model(input_image)
118
+ elif main_prediction == 'Mobile Phones':
119
+ sub_output = phones_model(input_image)
120
+
121
+ sub_probabilities = torch.nn.functional.softmax(sub_output[0], dim=0)
122
+ sub_confidence, sub_predicted_class = torch.max(sub_probabilities, 0)
123
+
124
+ # Display the sub-classifier result
125
+ st.write(f"**Sub Predicted Class:** {sub_class_names[sub_predicted_class]}")
126
+ st.write(f"**Confidence:** {sub_confidence.item():.4f}")
Clothes_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eddc1323a0b208ebb6dda79a8fe273e95ab0430ce6ab095be71849cf0b9bb010
3
+ size 44791416
Main_Classifier_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e34b538419ac07e0f4f8182b5ae479eb6cc2ef77332fe444f6a05866d1eb9e56
3
+ size 44791416
Phone_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b16817f04fbec8c2367c0dcc98b40572f6d56ad8246d35ba116316fea40fd8b8
3
+ size 44789368
Soda_drinks_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b8d0686fdf1449a52e9d9dedfd3f1bdbb677aa1c39c91bcfd02ff138784c5f5
3
+ size 44791416
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.34.0
2
+ Pillow==10.3.0
3
+ torch==2.4.1
4
+ torchvision==0.19.1
5
+ numpy==1.26.4
6
+
7
+
8
+
9
+
10
+
tshirt_pants_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b06bed9897948c4abc9c5e15d075308c06886ddb79e9a886eb705e132da625c7
3
+ size 85320024