"""Gradio "Product Search" demo.

Loads a CLIP checkpoint and the ``ashraq/fashion-product-images-small``
dataset at import time, then serves a small Gradio interface.  The retrieval
step is currently a stub: ``initial_query`` echoes the user's text three times
instead of ranking dataset entries.
"""

import html

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import CLIPModel, CLIPProcessor

# Load the pre-trained CLIP model and its processor.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the fashion product images dataset from Hugging Face.
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]

# Built once so preprocess_image does not reconstruct the pipeline per call.
# NOTE(review): the Normalize constants look like CLIP's published per-channel
# mean/std -- confirm against the checkpoint's preprocessing config.
_IMAGE_TRANSFORM = Compose([
    Resize(256, interpolation=Image.BICUBIC),
    CenterCrop(224),
    ToTensor(),
    Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    ),
])


def preprocess_image(image):
    """Run the resize/crop/normalize pipeline and add a leading batch axis.

    Args:
        image: Image accepted by the torchvision transforms (presumably a
            ``PIL.Image`` -- confirm against callers).

    Returns:
        The transformed tensor with an extra dimension inserted at index 0.
    """
    return _IMAGE_TRANSFORM(image).unsqueeze(0)


def initial_query(image, text):
    """Return three placeholder results that echo the user's text.

    This is a stub: ``image`` is ignored and nothing is looked up in the
    dataset or scored with the model.

    Args:
        image: Unused.
        text: Text stored as each result's ``"description"``.

    Returns:
        A list of three ``(product, 1)`` tuples.  All three share the same
        ``product`` dict, whose ``"image_path"`` is ``None``.  The constant
        ``1`` is presumably a similarity score -- confirm once real retrieval
        is implemented.
    """
    input_product = {"description": text, "image_path": None}
    return [(input_product, 1) for _ in range(3)]


def _escape_description(description):
    """Return ``description`` HTML-escaped; ``None`` becomes an empty string."""
    if description is None:
        return ""
    return html.escape(str(description))


def generate_output_html(products):
    """Render the products' descriptions as an HTML ordered list.

    Args:
        products: Iterable of tuples whose first element is a mapping with a
            ``"description"`` key.

    Returns:
        An HTML string with one ``<li>`` per product inside a single ``<ol>``.
    """
    # NOTE(review): the list markup here is a reconstruction -- confirm it
    # matches the intended layout (e.g. whether an <img> for "image_path"
    # should be rendered as well).
    # Descriptions come from the UI text box, so escape them before embedding
    # them in HTML that the browser will render.
    list_items = "".join(
        "<li>{}</li>".format(_escape_description(entry[0]["description"]))
        for entry in products
    )
    return f"<ol>{list_items}</ol>"


def initial_query_wrapper(image, text):
    """Run ``initial_query`` and render its results.

    Returns:
        A 1-tuple holding the HTML string.
        NOTE(review): the interface declares a single output component;
        confirm the installed Gradio version expects a tuple here rather than
        the bare string.
    """
    return (generate_output_html(initial_query(image, text)),)


def product_search_wrapper(
    image=None,
    text=None,
    selected_product_index=None,
    additional_text=None,
):
    """Gradio callback: render results when an image or text was supplied.

    Args:
        image: Optional image input; only checked for ``None``.
        text: Optional text query.
        selected_product_index: Unused; kept for interface compatibility.
        additional_text: Unused; kept for interface compatibility.

    Returns:
        A 1-tuple holding the results HTML, or ``("",)`` when both ``image``
        and ``text`` are ``None``.
    """
    if image is None and text is None:
        return ("",)
    return initial_query_wrapper(image, text)


# NOTE(review): gr.inputs / gr.outputs, ``optional=`` and ``layout=`` belong to
# the legacy Interface API; confirm the pinned Gradio version still accepts
# them before upgrading.
iface = gr.Interface(
    fn=product_search_wrapper,
    inputs=[
        gr.inputs.Image(optional=True),
        gr.inputs.Textbox(lines=3, label="Initial Text Query", optional=True),
    ],
    outputs=[
        gr.outputs.HTML(label="Results"),
    ],
    title="Product Search",
    description="Find the best matching products using images and text queries.",
    layout="vertical",
)

iface.launch()