import base64
import io

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset

# Load the pre-trained CLIP model and its processor (tokenizer + image preprocessing)
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
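# from_pretrained loads the weights and leaves the model in eval mode, so the
# torch.no_grad() blocks below are all that is needed for inference.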

# Load the fashion product images dataset from Hugging Face
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]
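# Note: this split contains tens of thousands of products and every query below
# runs a CLIP forward pass per product, so a sub-sampled slice (e.g. via
# dataset["train"].select(...)) may be preferable for an interactive demo.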

# Define the preprocessing function for images (CLIP's standard transform:
# resize, center-crop to 224x224, normalize with CLIP's channel statistics)
def preprocess_image(image):
    preprocess = Compose([
        Resize(256, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(224),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    # Add a batch dimension so the tensor can be fed directly to the model.
    return preprocess(image).unsqueeze(0)

# Run the first image+text query against the full product database.
def initial_query(image, text):
    return process_query(image, text, deepfashion_database)

def process_query(image, text, database):
    # Encode the query image and text once; both embeddings are reused for every product.
    image_tensor = preprocess_image(image.convert("RGB"))
    text_inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        query_image_emb = model.get_image_features(pixel_values=image_tensor)
        query_text_emb = model.get_text_features(**text_inputs)
    query_image_emb = query_image_emb / query_image_emb.norm(dim=-1, keepdim=True)
    query_text_emb = query_text_emb / query_text_emb.norm(dim=-1, keepdim=True)

    product_scores = []
    for product in database:
        # The dataset stores a PIL image under "image" and the product name
        # under "productDisplayName"; there is no file-path or description column.
        product_image_tensor = preprocess_image(product["image"].convert("RGB"))
        product_inputs = processor(text=[product["productDisplayName"]], return_tensors="pt",
                                   padding=True, truncation=True)

        with torch.no_grad():
            product_image_emb = model.get_image_features(pixel_values=product_image_tensor)
            product_text_emb = model.get_text_features(**product_inputs)
        product_image_emb = product_image_emb / product_image_emb.norm(dim=-1, keepdim=True)
        product_text_emb = product_text_emb / product_text_emb.norm(dim=-1, keepdim=True)

        # Score the product by the cosine similarity of the query text to the product
        # image and of the query image to the product text, averaged into one value.
        product_similarity = 0.5 * ((query_text_emb @ product_image_emb.T).item()
                                    + (query_image_emb @ product_text_emb.T).item())
        product_scores.append((product, product_similarity))

    # Keep the three highest-scoring (product, score) pairs.
    top_3_products = sorted(product_scores, key=lambda x: x[1], reverse=True)[:3]
    return top_3_products

def refine_query(selected_product_index, additional_text, initial_results):
    # initial_results holds (product, score) tuples returned by process_query.
    selected_product, _ = initial_results[selected_product_index]
    modified_description = selected_product["productDisplayName"] + " " + additional_text
    # Copy the selected product with the refinement text appended to its name.
    refined_product = dict(selected_product, productDisplayName=modified_description)
    refined_database = [product for i, (product, _) in enumerate(initial_results) if i != selected_product_index]
    refined_database.append(refined_product)
    return process_query(selected_product["image"].convert("RGB"), modified_description, refined_database)

def generate_output_html(products):
    # products is a list of (product, score) tuples; embed each image as a
    # base64 data URI so the HTML renders without relying on local file paths.
    html_output = "<ol>"
    for product, score in products:
        buffer = io.BytesIO()
        product["image"].convert("RGB").save(buffer, format="JPEG")
        src = "data:image/jpeg;base64," + base64.b64encode(buffer.getvalue()).decode("utf-8")
        html_output += f'<li><img src="{src}" width="100" height="100"><br>{product["productDisplayName"]} (score: {score:.3f})</li>'
    html_output += "</ol>"
    return html_output

def initial_query_wrapper(image, text):
    top_3_products = initial_query(image, text)
    return generate_output_html(top_3_products)

def refine_query_wrapper(selected_product_index, additional_text, initial_results):
    # The radio component delivers the index as a string, so cast it before use.
    top_3_products = refine_query(int(selected_product_index), additional_text, initial_results)
    return generate_output_html(top_3_products)

# The most recent initial-query results, kept so a follow-up refinement can re-rank them.
last_results = []

def product_search_wrapper(image=None, text=None, selected_product_index=None, additional_text=None):
    global last_results
    if image is not None and text:
        last_results = initial_query(image, text)
        return generate_output_html(last_results)
    elif selected_product_index is not None and additional_text and last_results:
        # The radio value arrives as a string ("0", "1", "2"), so cast it to an index.
        top_3_products = refine_query(int(selected_product_index), additional_text, last_results)
        return generate_output_html(top_3_products)
    else:
        return ""

# Gradio 3+ uses gr.Image / gr.Textbox / gr.Radio / gr.HTML directly; the old
# gr.inputs / gr.outputs namespaces and the optional/layout arguments are gone.
# All inputs may be left blank, which product_search_wrapper handles explicitly.
iface = gr.Interface(
    fn=product_search_wrapper,
    inputs=[
        gr.Image(type="pil", label="Query Image"),
        gr.Textbox(lines=3, label="Initial Text Query"),
        gr.Radio(["0", "1", "2"], label="Select Product Index"),
        gr.Textbox(lines=3, label="Additional Text Query"),
    ],
    outputs=gr.HTML(label="Results"),
    title="Product Search",
    description="Find the best matching products using images and text queries.",
)
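
# Typical flow: submit an image plus the initial text query to get the top-3
# matches, then choose an index (0-2) and add refinement text to re-rank those
# results against the extra description.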

# iface = gr.Interface(
#     fn=[initial_query_wrapper, refine_query_wrapper],
#     inputs=[
#         [gr.inputs.Image(), gr.inputs.Textbox(lines=3, label="Initial Text Query")],
#         [gr.inputs.Radio(["0", "1", "2"], label="Select Product Index"), gr.inputs.Textbox(lines=3, label="Additional Text Query"), gr.inputs.Hidden(initial_results="initial_query")]
#     ],
#     outputs=[
#         gr.outputs.HTML(label="Top 3 Matches"),
#         gr.outputs.HTML(label="Refined Top 3 Matches")
#     ],
#     title="Product Search",
#     description="Find the best matching products using images and text queries.",
#     layout="vertical"
# )

iface.launch()