File size: 3,627 Bytes
f02b11f 2300238 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
from io import BytesIO
from base64 import b64encode
from tqdm.auto import tqdm
from PIL import Image
import gradio as gr
from constants import *
from search import SearchItem
from fastapi import FastAPI
# Initialize the connection to Pinecone (get an API key at app.pinecone.io).
# The value imported from `constants` wins; when it is empty/None we fall back
# to the environment. os.getenv() expects the variable NAME, so it is given the
# literal string here -- passing the (falsy) constant itself could never find
# the variable and raises TypeError when the constant is None.
api_key = PINECONE_API_KEY or os.getenv("PINECONE_API_KEY")
# The environment name is shown next to the API key in the Pinecone console.
env = PINECONE_ENVIRONMENT or os.getenv("PINECONE_ENVIRONMENT")
# Shared search helper: the retrieval functions in this file read its `bm25`,
# `clip_model`, `index` and `images` attributes and call `hybrid_scale`.
fashion_processor = SearchItem(api_key, env)
def retrieve_images(query, image=None):
    """Route a search request to the matching retrieval function.

    Args:
        query: Text passed on to the retrieval function.
        image: Optional uploaded image. When it is truthy the search goes
            through ``retrieve_image_from_image``; otherwise only the text
            query is used.

    Returns:
        Whatever the selected retrieval function returns.
    """
    # Text-only path: nothing was uploaded.
    if not image:
        return retrieve_image_from_query(query)
    # An image was supplied, so search with it (plus the text query).
    return retrieve_image_from_image(image, query)
def retrieve_image_from_query(query):
    """Run a hybrid (sparse + dense) index query for a text query.

    Args:
        query: Text handed to both the BM25 encoder and the CLIP model.

    Returns:
        A list with one entry of ``fashion_processor.images`` per match, in
        the order the index returned the matches (at most 10, per ``top_k``).
    """
    # Encode the query twice: sparse via BM25, dense via the CLIP model.
    sparse_vec = fashion_processor.bm25.encode_queries(query)
    dense_vec = fashion_processor.clip_model.encode(query).tolist()

    # Let the processor weight the two representations before querying.
    scaled_dense, scaled_sparse = fashion_processor.hybrid_scale(dense_vec, sparse_vec)

    response = fashion_processor.index.query(
        top_k=10,
        vector=scaled_dense,
        sparse_vector=scaled_sparse,
        include_metadata=True,
    )

    # Match ids are converted to int and used as positions into `images`.
    matched_images = []
    for match in response["matches"]:
        matched_images.append(fashion_processor.images[int(match["id"])])
    return matched_images
def retrieve_image_from_image(image, query):
    """Run a hybrid index query using an uploaded image plus a text query.

    The dense vector comes from the CLIP model applied to the image (resized
    to 60x80); the sparse vector comes from BM25 applied to ``query``.

    Args:
        image: Uploaded file object; its ``name`` attribute is opened as the
            image path.
        query: Text handed to the BM25 encoder.

    Returns:
        A list with one entry of ``fashion_processor.images`` per match (at
        most 10, per ``top_k``), or ``None`` if anything in the pipeline
        raised. Failures are printed rather than propagated so the UI
        callback does not crash.
    """
    try:
        # Create sparse and dense vectors.
        sparse = fashion_processor.bm25.encode_queries(query)
        w, h = 60, 80
        # Open the upload in a context manager so the file handle is released
        # deterministically; resize() returns a new image that stays valid
        # after the source file is closed.
        with Image.open(image.name) as source_image:
            resized = source_image.resize((w, h))
        dense = fashion_processor.clip_model.encode(resized).tolist()
        hdense, hsparse = fashion_processor.hybrid_scale(dense, sparse)
        result = fashion_processor.index.query(
            top_k=10,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
        )
        # Match ids are converted to int and used as positions into `images`.
        return [fashion_processor.images[int(r["id"])] for r in result["matches"]]
    except Exception as e:
        # Deliberate best-effort: report the problem and return no results.
        print(e)
        return None
def show_img(image):
    """Return the path of an uploaded image, or a placeholder message.

    Args:
        image: Uploaded file object exposing a ``name`` attribute, or a falsy
            value when nothing was uploaded.

    Returns:
        ``image.name`` when ``image`` is truthy, otherwise the string
        ``"No image provided"``.
    """
    # NOTE(review): in this file the result is routed to a gr.Image output;
    # confirm that component copes with the placeholder string.
    if not image:
        return "No image provided"
    return image.name
# Build the Gradio UI: a text search box and an image-upload path on the left,
# a results gallery on the right.
with gr.Blocks() as demo:
    # The Markdown text is kept flush-left so its content does not depend on
    # any dedenting performed by the component.
    gr.Markdown(
        """
# Shopping Search Engine
Look for the ideal clothing items 😎
""")
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(placeholder="Search Items")
            gr.HTML("OR")
            # Preview of the uploaded image (filled in by show_img below).
            photo = gr.Image()
            with gr.Row():
                button = gr.UploadButton(label="Upload Image", file_types=["image"])
                textbox = gr.Textbox(placeholder="Additional Details ?")
            # The button label is set through `value`; `text` is not the
            # label parameter of gr.Button, so the label was not applied.
            submit_button = gr.Button(value="Submit")
        with gr.Column():
            # NOTE(review): `.style()` is deprecated in newer Gradio releases;
            # confirm the pinned Gradio version before upgrading.
            gallery = gr.Gallery().style(
                object_fit='contain',
                height='auto',
                preview=True
            )

    # Pressing Enter in the search box runs a text-only search.
    query.submit(fn=retrieve_images, inputs=[query], outputs=[gallery])
    # Submit first shows the uploaded image, then runs the image search using
    # the "Additional Details" text as the query.
    submit_button.click(fn=show_img, inputs=[button], outputs=[photo]) \
        .then(fn=lambda image, query: retrieve_images(query, image), inputs=[button, textbox], outputs=[gallery])

#demo.launch(server_name="0.0.0.0", server_port=8000)
demo.launch()