File size: 2,277 Bytes
c37abe8
47d7fa1
c37abe8
47d7fa1
506e7cd
47d7fa1
 
506e7cd
47d7fa1
 
 
 
506e7cd
 
 
 
c37abe8
4be0554
c37abe8
 
 
 
 
4be0554
c37abe8
a0c768e
c37abe8
 
 
 
 
fbbda08
c37abe8
fbbda08
c37abe8
 
 
 
 
 
fbbda08
c37abe8
 
 
fbbda08
c37abe8
 
 
 
 
 
fbbda08
c37abe8
 
 
 
 
 
 
 
 
 
 
 
 
fbbda08
c37abe8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import html

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    Resize,
    ToTensor,
)
from transformers import CLIPModel, CLIPProcessor

# Pretrained CLIP checkpoint. CLIPProcessor bundles the text tokenizer with the
# image preprocessor; CLIPModel holds the image/text encoder weights.
# NOTE(review): neither `processor` nor `model` is referenced in the code
# visible here — confirm they are used elsewhere before keeping the download.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path=model_name)
model = CLIPModel.from_pretrained(pretrained_model_name_or_path=model_name)

# Fashion-product images dataset from the Hugging Face Hub; only its "train"
# split is kept, exposed as `deepfashion_database`.
dataset = load_dataset(path="ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]

# Define the preprocessing function for images
def preprocess_image(image):
    preprocess = Compose([
        Resize(256, interpolation=Image.BICUBIC),
        CenterCrop(224),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    return preprocess(image).unsqueeze(0)

def initial_query(image, text):
    """Return placeholder search results that echo the user's text query.

    This is a stub: ``image`` is ignored and no search is performed. The text
    is wrapped in a product record (``description`` set to the text,
    ``image_path`` set to ``None``), and that same record is returned three
    times, each paired with the integer ``1`` (presumably a placeholder
    similarity score).

    Args:
        image: Query image; currently unused.
        text: Query text, copied verbatim into ``description``.

    Returns:
        A list of three ``(product_dict, 1)`` tuples that all share one
        product dict.
    """
    echoed_product = {"description": text, "image_path": None}
    return [(echoed_product, 1) for _ in range(3)]

# Result rendering and Gradio callback wiring for the search UI.

def generate_output_html(products):
    """Render search results as an HTML ordered list.

    Args:
        products: Iterable of ``(product, score)`` pairs where ``product`` is
            a mapping with a ``"description"`` key. Only the description is
            rendered; the second element is ignored.

    Returns:
        An ``<ol>`` element containing one ``<li>`` per product, in input
        order (``"<ol></ol>"`` for no products).
    """
    # Descriptions may come straight from user input (the query text), so
    # escape them rather than letting them be injected into the page as markup.
    # str() keeps non-string descriptions (e.g. None) renderable, as the plain
    # f-string interpolation did.
    items = "".join(
        f"<li>{html.escape(str(product[0]['description']))}</li>"
        for product in products
    )
    return f"<ol>{items}</ol>"

def initial_query_wrapper(image, text):
    """Run the initial query and package its HTML for a Gradio callback.

    Args:
        image: Query image, passed through to ``initial_query``.
        text: Query text, passed through to ``initial_query``.

    Returns:
        A one-element tuple holding the HTML that ``generate_output_html``
        produces for the results of ``initial_query``.
    """
    results = initial_query(image, text)
    rendered = generate_output_html(results)
    return (rendered,)

def product_search_wrapper(image=None, text=None, selected_product_index=None, additional_text=None):
    """Gradio callback: run the product search and render the results as HTML.

    Args:
        image: Query image from the image input, or ``None`` when no image was
            supplied.
        text: Query text. An empty Textbox is expected to arrive as ``""``
            rather than ``None``, so blank or whitespace-only text is treated
            the same as no text.
        selected_product_index: Unused; kept so existing callers keep working.
        additional_text: Unused; kept so existing callers keep working.

    Returns:
        A one-element tuple (one value per Gradio output). It holds the
        results HTML, or ``""`` when there is neither an image nor non-blank
        text.
    """
    # `text is not None` alone let an empty textbox ("") through, which
    # rendered a list of blank results instead of the empty state.
    has_text = bool(text and text.strip())
    # Compare the image with `is None`: it may be an array, whose truthiness
    # is ambiguous.
    if image is None and not has_text:
        return ("",)
    top_3_products = initial_query(image, text)
    return (generate_output_html(top_3_products),)

# Gradio UI: an image and a text query go in, the rendered HTML results come out.
# Uses the top-level component classes (gr.Image / gr.Textbox / gr.HTML). The
# gr.inputs / gr.outputs namespaces, the `optional=` flag and Interface's
# `layout=` argument were deprecated in Gradio 3 and removed in Gradio 4.
# Components can be left empty (the callback then receives None or an empty
# value), so no `optional=` flag is needed.
iface = gr.Interface(
    fn=product_search_wrapper,
    inputs=[
        gr.Image(),
        gr.Textbox(lines=3, label="Initial Text Query"),
    ],
    outputs=[
        gr.HTML(label="Results"),
    ],
    title="Product Search",
    description="Find the best matching products using images and text queries.",
)

# Starts the web server as soon as this module runs.
iface.launch()