# Hugging Face Space status when this file was captured: Runtime error
import html

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import CLIPProcessor, CLIPModel
# Hugging Face Hub identifier of the pre-trained CLIP checkpoint.
model_name = "openai/clip-vit-base-patch32"

# CLIP model weights for that checkpoint.
model = CLIPModel.from_pretrained(model_name)
# Matching CLIPProcessor for the same checkpoint (it wraps the text tokenizer
# together with the image preprocessing, not just a tokenizer).
processor = CLIPProcessor.from_pretrained(model_name)

# Fashion product catalogue from the Hugging Face Hub; its "train" split is
# used as the product database.
# NOTE(review): in the code visible here, model, processor and
# deepfashion_database are never read (initial_query is a stub that echoes
# the query), so these downloads currently only add startup time.
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]
# Define the preprocessing function for images | |
def preprocess_image(image): | |
preprocess = Compose([ | |
Resize(256, interpolation=Image.BICUBIC), | |
CenterCrop(224), | |
ToTensor(), | |
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), | |
]) | |
return preprocess(image).unsqueeze(0) | |
def initial_query(image, text):
    """Placeholder search: echo the text query back as three identical hits.

    ``image`` is accepted but ignored. Each hit is a ``(product, score)``
    pair. ``product`` is a dict whose ``"description"`` is the query text
    and whose ``"image_path"`` is None. ``score`` is the constant 1. All
    three hits share the same product dict.
    """
    echoed = {"description": text, "image_path": None}
    return [(echoed, 1)] * 3
def generate_output_html(products):
    """Render search results as an HTML ordered list.

    Args:
        products: iterable of result entries whose first element is a dict
            with a ``"description"`` key (e.g. ``(product, score)`` pairs);
            any further elements are ignored.

    Returns:
        An ``<ol>`` string with one ``<li>`` per entry, in input order.
    """
    # Descriptions can come straight from user-typed query text and the result
    # is rendered as raw HTML, so escape them to prevent markup injection.
    items = "".join(
        f"<li>{html.escape(str(entry[0]['description']))}</li>"
        for entry in products
    )
    return f"<ol>{items}</ol>"
def initial_query_wrapper(image, text):
    """Run the initial query and return its results rendered as HTML.

    Args:
        image: image query, forwarded unchanged to ``initial_query``.
        text: text query, forwarded unchanged to ``initial_query``.

    Returns:
        The HTML string produced by ``generate_output_html``.
    """
    top_3_products = initial_query(image, text)
    # Return the bare string: a trailing comma would wrap it in a 1-tuple,
    # which is not what a single HTML output component expects.
    return generate_output_html(top_3_products)
def product_search_wrapper(image=None, text=None, selected_product_index=None, additional_text=None):
    """Search entry point: query by image and/or text and render HTML results.

    Args:
        image: optional image query; None means no image was supplied.
        text: optional text query. A Gradio Textbox passes "" (not None)
            when left empty, so blank or whitespace-only text counts as
            "no text".
        selected_product_index: currently unused; kept for interface
            compatibility.
        additional_text: currently unused; kept for interface compatibility.

    Returns:
        The rendered HTML string, or "" when neither an image nor non-blank
        text was provided.
    """
    has_text = bool(text and text.strip())
    if image is None and not has_text:
        return ""
    top_3_products = initial_query(image, text)
    # Plain string (no trailing comma): the interface has a single output.
    return generate_output_html(top_3_products)
# Gradio UI: an image plus a multi-line text query (both may be left empty),
# with results rendered as HTML.
# Uses the top-level component classes (gr.Image, gr.Textbox, gr.HTML). The
# legacy gr.inputs / gr.outputs namespaces, their `optional=` argument and
# Interface's `layout=` argument are not available in current Gradio releases
# and are the likely cause of the app failing at startup.
iface = gr.Interface(
    fn=product_search_wrapper,
    inputs=[
        gr.Image(),
        gr.Textbox(lines=3, label="Initial Text Query"),
    ],
    outputs=[
        gr.HTML(label="Results"),
    ],
    title="Product Search",
    description="Find the best matching products using images and text queries.",
)
iface.launch()