File size: 3,758 Bytes
f02b11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910d0c4
f02b11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910d0c4
f02b11f
 
 
910d0c4
 
 
f02b11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910d0c4
f02b11f
910d0c4
 
f02b11f
 
 
 
910d0c4
f02b11f
 
910d0c4
f02b11f
910d0c4
f02b11f
 
 
 
 
 
 
 
 
 
 
910d0c4
f02b11f
 
b07503f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python
# -*- coding: utf-8 -*- 

import os
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
from io import BytesIO
from base64 import b64encode
from tqdm.auto import tqdm
from PIL import Image
import gradio as gr
from constants import *

from search import SearchItem







# Initialize the connection to Pinecone (get an API key at app.pinecone.io).
# Values from `constants` take precedence; when they are empty/None, fall back
# to the PINECONE_API_KEY / PINECONE_ENVIRONMENT environment variables.
# os.getenv expects the variable *name*, so it is passed as a string literal.
api_key = PINECONE_API_KEY or os.getenv("PINECONE_API_KEY")
# Find your environment next to the API key in the Pinecone console.
env = PINECONE_ENVIRONMENT or os.getenv("PINECONE_ENVIRONMENT")

# Shared search backend (BM25 encoder, CLIP model, index, image list) used by
# the retrieval functions in this module.
fashion_processor = SearchItem(api_key, env)
    

def retrieve_images(query, image=None):
    """Dispatch a search to the text-only or image+text retrieval path.

    Args:
        query: Text query, or the extra text details when an image is given.
        image: Optional uploaded file object; when truthy, the image-based
            search is used.

    Returns:
        Whatever the selected retrieval function returns.
    """
    if not image:
        # Nothing uploaded: search on the text query alone.
        return retrieve_image_from_query(query)
    # An upload is present: combine the image with the text details.
    return retrieve_image_from_image(image, query)
    


def retrieve_image_from_query(query):
    """Return the top-10 catalogue images matching a text query.

    A BM25 sparse vector and a CLIP dense vector are built from ``query``,
    weighted with ``fashion_processor.hybrid_scale`` and sent as a hybrid
    query to the Pinecone index. Each match id is converted to ``int`` and
    used to index ``fashion_processor.images``.
    """
    proc = fashion_processor

    sparse_vec = proc.bm25.encode_queries(query)
    dense_vec = proc.clip_model.encode(query).tolist()
    scaled_dense, scaled_sparse = proc.hybrid_scale(dense_vec, sparse_vec)

    response = proc.index.query(
        top_k=10,
        vector=scaled_dense,
        sparse_vector=scaled_sparse,
        include_metadata=True,
    )
    matches = response["matches"]
    return [proc.images[int(match["id"])] for match in matches]

    
def retrieve_image_from_image(image, query):
    """Return the top-10 catalogue images matching an uploaded image plus text.

    The uploaded file is resized to 60x80 and encoded with the CLIP model to
    form the dense vector, while ``query`` supplies the BM25 sparse vector.
    Both are weighted by ``fashion_processor.hybrid_scale`` before a hybrid
    Pinecone query.

    Args:
        image: Uploaded file object exposing a ``.name`` path.
        query: Extra text details. When empty or None, the placeholder text
            'No image' is used so the sparse encoder always receives text.

    Returns:
        A list of images from ``fashion_processor.images``, or None if any
        step fails (the error is printed).
    """
    try:
        # Text inputs typically yield "" (not None) when left blank, so treat
        # every falsy value as "no text provided".
        if not query:
            query = 'No image'

        # create sparse and dense vectors
        sparse = fashion_processor.bm25.encode_queries(query)

        w, h = 60, 80
        # Context manager releases the underlying file handle; resize()
        # returns a new, fully loaded image that outlives the `with`.
        with Image.open(image.name) as src:
            resized = src.resize((w, h))
        dense = fashion_processor.clip_model.encode(resized).tolist()
        hdense, hsparse = fashion_processor.hybrid_scale(dense, sparse)

        result = fashion_processor.index.query(
            top_k=10,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
        )

        return [fashion_processor.images[int(r["id"])] for r in result["matches"]]

    except Exception as e:
        # Best-effort UI callback: report the failure and return None
        # (no results) instead of raising into the UI event handler.
        print(e)
        return None



def show_img(image):
    """Return the file path of an uploaded image for display.

    Args:
        image: Uploaded file object exposing ``.name``, or None when nothing
            has been uploaded.

    Returns:
        The uploaded file's path, or None when no image is provided.
        None is returned (rather than a message string) because the result
        is fed to image/file output components, which interpret any string
        as a file path to load.
    """
    return image.name if image else None


# Gradio UI: a text search box plus an image-upload flow with optional extra
# text on the left; search results are shown in a gallery on the right.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Shopping Search Engine
    
    Look for the ideal clothing items 😆
    """)

    with gr.Row():
        with gr.Column():

            query = gr.Textbox(placeholder="Search Items")
            gr.HTML("OR LOAD IMAGE AND SPECIFIC TEXT DETAILS")
            photo = gr.Image()
            with gr.Row():
                file_output = gr.File()
                button = gr.UploadButton(label="Upload Image", file_types=["image"])
                # Echo the uploaded file into the File component.
                button.upload(show_img, button, file_output)
                textbox = gr.Textbox(placeholder="Additional Details ?")
                # A Button's label is its `value`; `text=` is not a Button
                # parameter, so it would not set the label.
                submit_button = gr.Button(value="Submit")

        with gr.Column():
            gallery = gr.Gallery().style(
                object_fit='contain',
                height='auto',
                preview=True
            )

    # Pressing Enter in the search box runs a text-only search.
    query.submit(fn=lambda query: retrieve_images(query), inputs=[query], outputs=[gallery])

    # Submit: first preview the uploaded image, then search with image + details.
    submit_button.click(fn=lambda image, query: show_img(image), inputs=[button, textbox], outputs=[photo]) \
        .then(fn=lambda image, query: retrieve_images(query, image), inputs=[button, textbox], outputs=[gallery])
demo.launch()