import streamlit as st
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the main classifier (Main_Classifier_best_model.pth)
main_model = models.resnet18(pretrained=False)
num_ftrs = main_model.fc.in_features
main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones
main_model.load_state_dict(torch.load('Main_Classifier_best_model.pth', map_location=device))
main_model = main_model.to(device)
main_model.eval()
# Define class names for the main classifier based on folder structure
main_class_names = ['Clothing', 'Mobile Phones', 'Soda drinks']
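# Note: this order assumes the training data was loaded with torchvision's ImageFolder,
# which sorts class folders alphabetically; the predicted index only maps to the right
# label if this list matches the order used during training.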
# Sub-classifier models
def load_soda_drinks_model():
    model = models.resnet18(pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 3)  # 3 classes: Miranda, Pepsi, Seven Up
    model.load_state_dict(torch.load('Soda_drinks_best_model.pth', map_location=device))
    model = model.to(device)
    model.eval()
    return model

def load_clothing_model():
    model = models.resnet18(pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: Pants, T-Shirt
    model.load_state_dict(torch.load('Clothes_best_model.pth', map_location=device))
    model = model.to(device)
    model.eval()
    return model

def load_mobile_phones_model():
    model = models.resnet18(pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: Apple, Samsung
    model.load_state_dict(torch.load('Phone_best_model.pth', map_location=device))
    model = model.to(device)
    model.eval()
    return model

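# Optional: Streamlit re-runs this script on every interaction, so the selected
# sub-classifier is re-loaded from disk for each classified image. A minimal sketch,
# assuming a Streamlit version that provides st.cache_resource, would decorate each
# loader so the weights are read only once per session:
#
#     @st.cache_resource
#     def load_soda_drinks_model():
#         ...
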
def convert_to_rgb(image):
    """
    Converts palette ('P'), grayscale, and 'RGBA' images to 'RGB' so that every
    input has three channels and transparency does not cause issues at inference.
    """
    if image.mode != 'RGB':
        return image.convert('RGB')
    return image

# Define preprocessing transformations (same used during training)
preprocess = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet normalization
])

# Streamlit App Interface
st.title("Main Classifier and Sub-Classifier System")
st.write("Upload an image to classify whether it belongs to Clothing, Mobile Phones, or Soda Drinks. Based on the prediction, it will further classify within the subcategory.")
# Image uploader in Streamlit
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Open the image using PIL
    image = Image.open(uploaded_file)

    # Display the uploaded image
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Preprocess the image and add a batch dimension
    input_image = preprocess(image).unsqueeze(0).to(device)

    # Perform inference with the main classifier
    with torch.no_grad():
        output = main_model(input_image)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_class = torch.max(probabilities, 0)

    # Display the main classifier result
    main_prediction = main_class_names[predicted_class]
    st.write(f"**Main Predicted Class:** {main_prediction}")
    st.write(f"**Confidence:** {confidence.item():.4f}")

    # Load the sub-classifier that matches the main classification
    if main_prediction == 'Soda drinks':
        st.write("Loading Soda Drinks Model...")
        sub_model = load_soda_drinks_model()
        sub_class_names = ['Miranda', 'Pepsi', 'Seven Up']
    elif main_prediction == 'Clothing':
        st.write("Loading Clothing Model...")
        sub_model = load_clothing_model()
        sub_class_names = ['Pants', 'T-Shirt']
    else:  # 'Mobile Phones'
        st.write("Loading Mobile Phones Model...")
        sub_model = load_mobile_phones_model()
        sub_class_names = ['Apple', 'Samsung']

    # Perform inference with the sub-classifier
    with torch.no_grad():
        sub_output = sub_model(input_image)
        sub_probabilities = torch.nn.functional.softmax(sub_output[0], dim=0)
        sub_confidence, sub_predicted_class = torch.max(sub_probabilities, 0)

    # Display the sub-classifier result
    st.write(f"**Sub Predicted Class:** {sub_class_names[sub_predicted_class]}")
    st.write(f"**Confidence:** {sub_confidence.item():.4f}")