|
|
|
|
|
|
|
import logging
import os
from base64 import b64encode
from io import BytesIO

import gradio as gr
import torch
from datasets import load_dataset
from fastapi import FastAPI
from PIL import Image
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

from constants import *

from search import SearchItem
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pinecone credentials. Prefer the values presumably provided by the
# star-imported ``constants`` module. If those are empty/None, fall back to the
# environment variables of the same name. The variable NAME must be passed to
# os.getenv; passing the (falsy) constant's value made the fallback useless,
# and raised TypeError when the constant was None.
api_key = PINECONE_API_KEY or os.getenv("PINECONE_API_KEY")

env = PINECONE_ENVIRONMENT or os.getenv("PINECONE_ENVIRONMENT")

# Shared search backend (BM25 encoder, CLIP model, index, image catalogue)
# used by the retrieval functions in this module.
fashion_processor = SearchItem(api_key, env)
|
|
|
|
|
def retrieve_images(query, image=None):
    """Dispatch a search to the text-based or image-based retriever.

    Without an uploaded ``image`` (falsy), a pure text search on ``query`` is
    run. Otherwise the image drives the dense part of the search and
    ``query`` supplies the keyword (sparse) part.

    Returns whatever the selected retriever returns.
    """
    if not image:
        return retrieve_image_from_query(query)
    return retrieve_image_from_image(image, query)
|
|
|
|
|
|
|
def retrieve_image_from_query(query, top_k=10):
    """Return catalogue images matching a free-text ``query``.

    A hybrid query is built from the BM25 sparse encoding and the
    ``clip_model`` dense embedding of the same text. The two are combined by
    ``fashion_processor.hybrid_scale`` and sent to ``fashion_processor.index``.

    Args:
        query: Search text typed by the user.
        top_k: Maximum number of matches to request from the index. Defaults
            to 10, the previously hard-coded value.

    Returns:
        A list of entries from ``fashion_processor.images``, one per match,
        in the order the index returned them.
    """
    sparse = fashion_processor.bm25.encode_queries(query)
    dense = fashion_processor.clip_model.encode(query).tolist()
    hdense, hsparse = fashion_processor.hybrid_scale(dense, sparse)

    result = fashion_processor.index.query(
        top_k=top_k,
        vector=hdense,
        sparse_vector=hsparse,
        include_metadata=True,
    )

    # Match ids are assumed to be stringified integer positions into
    # fashion_processor.images (that is how they are used here).
    return [fashion_processor.images[int(match["id"])] for match in result["matches"]]
|
|
|
|
|
def retrieve_image_from_image(image, query, top_k=10):
    """Return catalogue images similar to an uploaded ``image``.

    The dense part of the hybrid query is the ``clip_model`` embedding of the
    uploaded picture. The sparse part is the BM25 encoding of ``query`` (the
    optional free-text details). Both are combined by
    ``fashion_processor.hybrid_scale`` before querying the index.

    Args:
        image: The upload. This is either an object exposing a ``.name``
            file path (such as an uploaded temp file) or a plain path.
        query: Free-text details used for the keyword (sparse) part.
        top_k: Maximum number of matches to request. Defaults to 10.

    Returns:
        A list of entries from ``fashion_processor.images``. Returns ``None``
        if any step fails; the error is logged with its traceback rather than
        propagated to the caller.
    """
    try:
        sparse = fashion_processor.bm25.encode_queries(query)

        # NOTE(review): the upload is shrunk to 60x80 before encoding,
        # presumably to match the size of the indexed catalogue images. Confirm.
        w, h = 60, 80
        path = getattr(image, "name", image)
        # Close the underlying file handle deterministically. resize() loads
        # the pixel data and returns a new image that outlives the file.
        with Image.open(path) as uploaded:
            resized = uploaded.resize((w, h))

        dense = fashion_processor.clip_model.encode(resized).tolist()
        hdense, hsparse = fashion_processor.hybrid_scale(dense, sparse)

        result = fashion_processor.index.query(
            top_k=top_k,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
        )

        # Match ids are assumed to be stringified integer positions into
        # fashion_processor.images.
        return [fashion_processor.images[int(match["id"])] for match in result["matches"]]
    except Exception:
        # UI boundary: keep the best-effort "return None" contract, but record
        # the full traceback instead of just printing the message.
        logging.getLogger(__name__).exception("Image-based search failed")
        return None
|
|
|
|
|
|
|
def show_img(image):
    """Return the file path of the uploaded ``image`` for a ``gr.Image`` output.

    ``image`` may be an object exposing a ``.name`` path (such as an uploaded
    temp file) or a plain path string.

    When nothing was uploaded, ``None`` is returned so the image component is
    cleared. The former placeholder string ``"No image provided"`` is not an
    image path, so an image output could not render it.
    """
    if not image:
        return None
    return getattr(image, "name", image)
|
|
|
|
|
# Gradio UI. The user either types a text search, or uploads an image plus
# optional extra details. Results are shown in a gallery.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Shopping Search Engine

        Look for the ideal clothing items 😎
        """)

    with gr.Row():
        with gr.Column():
            query = gr.Textbox(placeholder="Search Items")
            gr.HTML("OR")
            photo = gr.Image()
            with gr.Row():
                button = gr.UploadButton(label="Upload Image", file_types=["image"])
                textbox = gr.Textbox(placeholder="Additional Details ?")
            # ``value`` is gr.Button's label keyword. The former ``text=`` is
            # not a Button parameter, so the "Submit" label was never applied.
            submit_button = gr.Button(value="Submit")

        with gr.Column():
            gallery = gr.Gallery().style(
                object_fit='contain',
                height='auto',
                preview=True
            )

    # Enter in the search box runs a text-only search. Lambda parameters get
    # distinct names so they don't shadow the ``query`` component.
    query.submit(fn=lambda text: retrieve_images(text), inputs=[query], outputs=[gallery])
    # Submit first previews the upload, then runs the search. The details text
    # supplies the keyword part.
    submit_button.click(fn=lambda upload, details: show_img(upload), inputs=[button, textbox], outputs=[photo]) \
        .then(fn=lambda upload, details: retrieve_images(details, upload), inputs=[button, textbox], outputs=[gallery])

demo.launch()