# product-classification-pipeline / model_pipeline.py
# krishuggingface's picture
# Upload model_pipeline.py with huggingface_hub
# 464c27a verified
import os
import re
from io import BytesIO

import google.generativeai as genai
import matplotlib.pyplot as plt
import pandas as pd
import requests
from google.colab import files
from PIL import Image
# Configure the Generative AI client. The key is read from the environment
# (GOOGLE_API_KEY) instead of being hard-coded, for API key security.
api_key = os.environ.get('GOOGLE_API_KEY')
genai.configure(api_key=api_key)

# Categories the model may choose from; this list is interpolated into the
# prompt built by generate_product_details().
categories = [
    "Personal Care",
    "Household Care",
    "Dairy",
    "Staples",
    "Snacks and Beverages",
    "Packaged Food",
    "Fruits and Vegetables",
]
# Step 1: Download image from URL
def download_image(image_url):
    """Download an image from a URL and save it locally as a JPEG file.

    Args:
        image_url: URL of the image to fetch.

    Returns:
        The path of the saved file ("temp_image.jpg"), or None when the
        request, decoding, or save step fails (the error is printed).
    """
    temp_path = "temp_image.jpg"  # Temporary path
    try:
        # Without a timeout a stalled server would block the pipeline forever.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()  # Check if the request was successful
        with Image.open(BytesIO(response.content)) as img:
            # JPEG cannot store alpha or palette data, so images in modes such
            # as RGBA or P (typical for PNGs) must be converted before saving.
            if img.mode not in ("L", "RGB", "CMYK"):
                img = img.convert("RGB")
            img.save(temp_path)  # Save the image locally for further use
        return temp_path
    except Exception as e:
        # Deliberately broad: this is a best-effort boundary whose contract is
        # "return None on any failure" so the caller can report and stop.
        print(f"Error downloading image: {e}")
        return None
# Step 2: Upload Image to the API
def upload_image(image_path):
    """Upload a local image through ``genai.upload_file`` and return the result.

    Args:
        image_path: Path of the image file to upload.

    Returns:
        The object returned by ``genai.upload_file``; its ``display_name``
        and ``uri`` attributes are printed as confirmation.
    """
    uploaded = genai.upload_file(path=image_path, display_name="Product Image")
    shown_name, remote_uri = uploaded.display_name, uploaded.uri
    print(f"Uploaded file '{shown_name}' as: {remote_uri}")
    return uploaded
# Step 3: Display Image
def display_image(image_path):
    """Render the image stored at ``image_path`` with matplotlib, axes hidden.

    Args:
        image_path: Path of the image file to show.
    """
    # The context manager closes the underlying file handle once the pixel
    # data has been handed to matplotlib, instead of leaking it.
    with Image.open(image_path) as img:
        plt.imshow(img)
    plt.axis('off')
    plt.show()
# Step 4: Classify image to decide whether it contains fruits/vegetables or other products
def classify_image(sample_file):
    """Ask the model whether the uploaded image shows fruits or vegetables.

    Args:
        sample_file: Uploaded file object to include in the model request.

    Returns:
        True when the first word of the model's answer is "yes", ignoring
        case, surrounding whitespace, quotes, and punctuation; False otherwise.
    """
    model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest")
    response = model.generate_content([sample_file, "Does this image contain fruits or vegetables? Answer 'yes' or 'no' only."])
    # The model does not always answer with a bare "yes": replies such as
    # "Yes." or "'yes'" must still count as affirmative, so only the first
    # word is compared after stripping quotes and punctuation.
    words = response.text.strip().lower().split()
    first_word = words[0].strip("'\"`*.,!:;") if words else ""
    return first_word == "yes"
# Step 5: Predict freshness (for fruits and vegetables)
def predict_freshness(sample_file):
    """Ask the model for the average freshness index of the produce shown.

    Args:
        sample_file: Uploaded file object to include in the model request.

    Returns:
        The freshness index as an int, or None when the answer contains no
        number (an error message is printed in that case). A decimal answer
        is rounded to the nearest integer.
    """
    model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest")
    response = model.generate_content([sample_file, "Can you provide the average freshness index (1-10) of the fruits/vegetables in the image. Just output the number."])
    # Take the first number in the reply so answers like "7.5", "8/10" or
    # "Freshness: 8" are understood instead of being rejected by int().
    number = re.search(r"\d+(?:\.\d+)?", response.text)
    if number is None:
        print("Error: Unable to convert the response to an integer.")
        return None
    return round(float(number.group()))
# Step 6: Generate product details (for other products)
def generate_product_details(sample_file):
    """Ask the model to list the details of every product in the image.

    The prompt requests one line per product in the form
    "Product Name: ..., Category: ..., Brand: ..., MRP: ..., Manufacturer: ...,
    Expiry Date: ..., Quantity: ...", with NA for missing details.

    Args:
        sample_file: Uploaded file object to include in the model request.

    Returns:
        The model's text answer with surrounding whitespace removed, or an
        empty string when the response object is falsy.
    """
    gemini = genai.GenerativeModel(model_name="gemini-1.5-pro-latest")
    prompt = (
        f"Tell me the name of each product, its category among the following list of categories: {categories}, brand, MRP, manufacturer, expiry date, and quantity in the image. "
        "Do not output anything else. Output format for each product: "
        "Product Name: [Extracted Product name], Category: [Extracted Category], Brand: [Extracted Brand name], MRP: [Extracted MRP], Manufacturer: [Extracted Manufacturer name], "
        "Expiry Date: [Extracted Expiry Date], Quantity: [Extracted Quantity]. Separate the details of each product with one newline character. "
        "If some of the information is not available for a product, then output NA for that detail."
    )
    result = gemini.generate_content([sample_file, prompt])
    if not result:
        return ""
    return result.text.strip()
# Step 7: Parse the response into a DataFrame
def parse_response_to_dataframe(response_text):
    """Parse the model's product listing into a DataFrame, one row per product.

    Each non-empty line is expected to look like
    "Product Name: ..., Category: ..., Brand: ..., MRP: ..., Manufacturer: ...,
    Expiry Date: ..., Quantity: ...". Fields may appear in any order; fields
    that are missing or empty are reported as "NA".

    Args:
        response_text: Raw text returned by the model.

    Returns:
        A DataFrame with the columns Product Name, Category, Brand, MRP,
        Manufacturer, Expiry Date and Quantity. Lines that contain none of
        these labels (blank lines, stray text) produce no row.
    """
    columns = ["Product Name", "Category", "Brand", "MRP", "Manufacturer", "Expiry Date", "Quantity"]
    # The field labels themselves delimit the values. Splitting on ", " and
    # ": " would truncate values such as "Nestle India Ltd., Gurgaon" and
    # raise IndexError on parts that lack ": ".
    label_pattern = re.compile(
        r"\b(" + "|".join(re.escape(col) for col in columns) + r")\s*:\s*"
    )
    products_list = []
    for line in response_text.splitlines():
        labels = list(label_pattern.finditer(line))
        if not labels:
            continue  # Blank line or text without any known field.
        product_details = {col: "NA" for col in columns}
        for position, label in enumerate(labels):
            # A value runs from the end of its label to the start of the next
            # label, or to the end of the line for the last field.
            if position + 1 < len(labels):
                value_end = labels[position + 1].start()
            else:
                value_end = len(line)
            # Drop the ", " separator that precedes the next label.
            value = line[label.end():value_end].strip().rstrip(",").strip()
            if value:
                product_details[label.group(1)] = value
        products_list.append(product_details)
    return pd.DataFrame(products_list, columns=columns)
# Step 8: Style the DataFrame for better display
def style_dataframe(df):
return df.style.set_properties(**{'text-align': 'center', 'border': '1px solid grey'}) .set_table_styles([{'selector': 'td', 'props': [('border', '1px solid grey')]}], overwrite=False)
# Step 9: Display results (image and styled DataFrame)
def display_results(image_path, styled_df):
    """Show the product image, then a heading followed by the styled table.

    Args:
        image_path: Path of the image, passed to ``display_image``.
        styled_df: Styled DataFrame passed to ``display``.
    """
    # Image first, table second.
    display_image(image_path)
    heading = "\nProduct Details:\n"
    print(heading)
    # NOTE(review): `display` is neither defined nor imported in this file;
    # presumably the notebook (IPython/Colab) built-in - confirm.
    display(styled_df)
# Step 10: Save DataFrame to CSV
def save_dataframe_to_csv(df, file_name="product_details.csv"):
    """Write ``df`` to a CSV file without the index column and report the path.

    Args:
        df: DataFrame to save.
        file_name: Destination path; defaults to "product_details.csv".
    """
    df.to_csv(path_or_buf=file_name, index=False)
    confirmation = f"DataFrame saved to {file_name}"
    print(confirmation)
# Combined Pipeline: Choose action based on image content
def combined_pipeline(image_source, is_url=False, csv_file_name="product_details.csv"):
    """Run the full pipeline on one image, choosing the action by its content.

    The image is uploaded and classified. For fruits or vegetables the
    predicted freshness index is printed. For any other image the product
    details are extracted, displayed next to the image, saved to a CSV file,
    and that file is offered for download through ``google.colab.files``.

    Args:
        image_source: Image URL when ``is_url`` is True, otherwise a local
            file path.
        is_url: Whether ``image_source`` must be downloaded first.
        csv_file_name: Name of the CSV file written for product images.
            Defaults to "product_details.csv", the name that was previously
            hard-coded.

    Returns:
        None. Progress and failures are reported with ``print``.
    """
    # Step 1: Download the image if it's a URL
    if is_url:
        image_path = download_image(image_source)
        if not image_path:
            print("Failed to download the image.")
            return
    else:
        image_path = image_source

    # Step 2: Upload the image
    sample_file = upload_image(image_path)
    if not sample_file:
        print("Error uploading image.")
        return

    # Step 3: Fruits/vegetables only get a freshness prediction
    if classify_image(sample_file):
        print("Image contains fruits or vegetables. Predicting freshness...")
        freshness_index = predict_freshness(sample_file)
        if freshness_index is not None:
            print(f"The predicted freshness index is: {freshness_index}")
        else:
            print("Failed to predict freshness.")
        return

    # Step 4: Every other image goes through product-detail extraction
    print("Image contains products. Extracting details...")
    response_text = generate_product_details(sample_file)
    if not response_text:
        print("No product details generated.")
        return

    df = parse_response_to_dataframe(response_text)
    styled_df = style_dataframe(df)
    display_results(image_path, styled_df)

    # Save the DataFrame to a CSV file, then download that same file
    save_dataframe_to_csv(df, csv_file_name)
    files.download(csv_file_name)