# Source: product-catalog / appv0.py
# Author: samnji
# Commit: e0f8afd ("step 6")
import gradio as gr
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import torch
# Load the pre-trained CLIP model and its processor. The processor is used
# further down to tokenize text queries and product descriptions.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
# Load the fashion product images dataset from Hugging Face and keep its
# "train" split as the searchable product database.
# NOTE(review): process_query reads the "image_path" and "description" keys
# from every product - confirm the rows of this dataset expose those keys.
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]
# Define the preprocessing function for images
def preprocess_image(image):
    """Convert an image into a normalized, single-item CLIP input batch.

    Args:
        image: A PIL image, or an object exposing the NumPy array interface
            (presumably what the Gradio image input passes to its callback -
            confirm against the Gradio version in use).

    Returns:
        The transformed image as a tensor with a leading batch dimension
        of one, i.e. shape (1, 3, 224, 224).
    """
    # The resize/crop transforms below work on PIL images, so array inputs
    # are wrapped first instead of failing inside the pipeline.
    if not isinstance(image, Image.Image) and hasattr(image, "__array_interface__"):
        image = Image.fromarray(image)
    # Normalize is configured with three-channel statistics; grayscale or
    # RGBA inputs would otherwise raise a channel-mismatch error.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    preprocess = Compose([
        Resize(256, interpolation=Image.BICUBIC),
        CenterCrop(224),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    return preprocess(image).unsqueeze(0)
# Entry point for a first search over the whole product database
def initial_query(image, text):
    """Run an image-plus-text query against the full product database.

    Delegates to ``process_query`` with the module-level
    ``deepfashion_database`` as the collection to search, and returns
    whatever ``process_query`` returns.
    """
    catalogue = deepfashion_database
    return process_query(image, text, catalogue)
def _clip_joint_embedding(image, text):
    """Return one unit-length vector combining CLIP image and text embeddings."""
    pixel_values = preprocess_image(image)
    # Keyword form keeps the text from being mistaken for another modality.
    tokens = processor(text=text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            pixel_values=pixel_values,
        )
    combined = outputs.image_embeds + outputs.text_embeds
    return combined / combined.norm(dim=-1, keepdim=True)


def process_query(image, text, database):
    """Rank products by CLIP similarity to an image-plus-text query.

    The previous implementation applied a softmax to a single
    image-to-text logit per product, which always yields 1.0, so every
    product received the same score and the query itself never influenced
    the ranking. Each product is now scored by the cosine similarity
    between its joint (image + description) embedding and the joint
    embedding of the query.

    Args:
        image: Query image, in any form ``preprocess_image`` accepts.
        text: Query text.
        database: Iterable of products. Each product is read through the
            "image_path" and "description" keys.
            NOTE(review): confirm the rows of the configured dataset
            actually provide these two keys.

    Returns:
        A list of at most three ``(product, score)`` tuples ordered from
        most to least similar, where ``score`` is a float.
    """
    query_embedding = _clip_joint_embedding(image, text)

    product_scores = []
    for product in database:
        # Close the file handle as soon as the pixels have been copied.
        with Image.open(product["image_path"]) as opened:
            product_image = opened.convert("RGB")
        product_embedding = _clip_joint_embedding(product_image, product["description"])
        # Both vectors are unit length, so the dot product is the cosine.
        similarity = (query_embedding * product_embedding).sum().item()
        product_scores.append((product, similarity))

    top_3_products = sorted(product_scores, key=lambda x: x[1], reverse=True)[:3]
    return top_3_products
def refine_query(selected_product_index, additional_text, initial_results):
    """Re-run the search around one previously returned product.

    The selected product's description is extended with
    ``additional_text``. The search is then repeated over the remaining
    earlier results plus that modified product, using the selected
    product's image as the query image.

    Args:
        selected_product_index: Position of the chosen product in
            ``initial_results``. May be an int or a numeric string, since
            the radio input in this file offers "0", "1" and "2".
        additional_text: Text appended to the chosen product's description.
        initial_results: Earlier results, either bare product dicts or the
            ``(product, score)`` tuples that ``process_query`` returns.

    Returns:
        Whatever ``process_query`` returns for the refined query.

    Raises:
        ValueError: If ``selected_product_index`` is not numeric.
        IndexError: If the index is outside ``initial_results``.
    """
    selected_product_index = int(selected_product_index)
    # Unwrap (product, score) pairs so the rest works on plain products.
    products = [
        entry[0] if isinstance(entry, (tuple, list)) else entry
        for entry in initial_results
    ]

    selected_product = products[selected_product_index]
    modified_description = selected_product["description"] + " " + additional_text
    refined_product = {
        "description": modified_description,
        "image_path": selected_product["image_path"],
    }

    refined_database = [
        product for i, product in enumerate(products) if i != selected_product_index
    ]
    refined_database.append(refined_product)

    # Load the query image eagerly so the file handle can be closed here.
    with Image.open(selected_product["image_path"]) as opened:
        query_image = opened.convert("RGB")
    return process_query(query_image, modified_description, refined_database)
def generate_output_html(products):
    """Render ranked products as an HTML ordered list.

    Args:
        products: Iterable of entries whose first element is a product
            mapping with "image_path" and "description" keys, such as the
            ``(product, score)`` tuples produced by ``process_query``.

    Returns:
        An ``<ol>`` string with one ``<li>`` per product, holding a
        100x100 image followed by the description.
    """
    # Local import keeps this block self-contained.
    from html import escape

    # Descriptions can contain text typed by the user (refined queries), so
    # both values are escaped before being placed into markup.
    items = [
        f'<li><img src="{escape(str(entry[0]["image_path"]), quote=True)}" '
        f'width="100" height="100"><br>{escape(str(entry[0]["description"]))}</li>'
        for entry in products
    ]
    return "<ol>" + "".join(items) + "</ol>"
def initial_query_wrapper(image, text):
    """Run an initial search and return its HTML rendering.

    Returns:
        A one-element tuple containing the HTML produced by
        ``generate_output_html`` for the results of ``initial_query``.
    """
    rendered = generate_output_html(initial_query(image, text))
    return (rendered,)
def refine_query_wrapper(selected_product_index, additional_text, initial_results):
    """Run a refined search and return its HTML rendering.

    Returns:
        A one-element tuple containing the HTML produced by
        ``generate_output_html`` for the results of ``refine_query``.
    """
    refined = refine_query(selected_product_index, additional_text, initial_results)
    rendered = generate_output_html(refined)
    return (rendered,)
def product_search_wrapper(image=None, text=None, selected_product_index=None, additional_text=None):
    """Dispatch a UI request to an initial search or a refinement.

    An image together with text starts a new search. Otherwise a selected
    index together with additional text refines the most recent results.
    The refinement branch previously called ``refine_query`` without its
    required ``initial_results`` argument and so always raised
    ``TypeError``; the latest results are now remembered between calls.

    NOTE: the remembered results live on this function object, so they are
    shared by every caller of the process rather than kept per session.

    Args:
        image: Query image, or None.
        text: Query text, or None.
        selected_product_index: Index into the previous results, as an int
            or numeric string, or None.
        additional_text: Text used to refine the selected product, or None.

    Returns:
        A one-element tuple with the results HTML, or with an empty string
        when there is nothing to search or refine.
    """
    if image is not None and text is not None:
        top_3_products = initial_query(image, text)
    elif selected_product_index is not None and additional_text is not None:
        previous = getattr(product_search_wrapper, "last_results", None)
        if not previous:
            # Nothing has been searched yet, so there is nothing to refine.
            return ("",)
        # Hand over plain products and an integer index: the stored entries
        # are (product, score) pairs and the radio input yields strings.
        candidates = [entry[0] for entry in previous]
        index = int(selected_product_index)
        if not 0 <= index < len(candidates):
            return ("",)
        top_3_products = refine_query(index, additional_text, candidates)
    else:
        return ("",)

    # Remember what is being shown so the next refinement indexes into it.
    product_search_wrapper.last_results = top_3_products
    return (generate_output_html(top_3_products),)
# Input widgets, in the positional order product_search_wrapper expects:
# query image, initial text, index of a previous result, refinement text.
# NOTE(review): gr.inputs / gr.outputs, `optional` and `layout` look like the
# legacy Gradio API - confirm the pinned Gradio version still provides them.
search_inputs = [
    gr.inputs.Image(optional=True),
    gr.inputs.Textbox(lines=3, label="Initial Text Query", optional=True),
    gr.inputs.Radio(["0", "1", "2"], label="Select Product Index", optional=True),
    gr.inputs.Textbox(lines=3, label="Additional Text Query", optional=True),
]

# A single HTML panel shows the rendered result list.
search_outputs = [
    gr.outputs.HTML(label="Results"),
]

iface = gr.Interface(
    fn=product_search_wrapper,
    inputs=search_inputs,
    outputs=search_outputs,
    title="Product Search",
    description="Find the best matching products using images and text queries.",
    layout="vertical",
)
# iface = gr.Interface(
# fn=[initial_query_wrapper, refine_query_wrapper],
# inputs=[
# [gr.inputs.Image(), gr.inputs.Textbox(lines=3, label="Initial Text Query")],
# [gr.inputs.Radio(["0", "1", "2"], label="Select Product Index"), gr.inputs.Textbox(lines=3, label="Additional Text Query"), gr.inputs.Hidden(initial_results="initial_query")]
# ],
# outputs=[
# gr.outputs.HTML(label="Top 3 Matches"),
# gr.outputs.HTML(label="Refined Top 3 Matches")
# ],
# title="Product Search",
# description="Find the best matching products using images and text queries.",
# layout="vertical"
# )
# Start the Gradio app for the `iface` interface built above.
iface.launch()