#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Gradio front end for a hybrid (sparse + dense) clothing image search.

The app builds a ``SearchItem`` helper from Pinecone credentials, exposes
text-only and image+text retrieval functions, and wires them into a Gradio
``Blocks`` UI that is launched when this file is executed.
"""
# NOTE: the import order is kept exactly as in the original file because
# ``from constants import *`` may shadow names imported before it.
import os
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
from io import BytesIO
from base64 import b64encode
from tqdm.auto import tqdm
from PIL import Image
import gradio as gr
from constants import *
from search import SearchItem
from fastapi import FastAPI

# Initialize the connection to Pinecone (get an API key at app.pinecone.io).
# The value from ``constants`` wins; when it is empty, fall back to the
# environment variable of the same name. The name must be passed as a string:
# passing the (empty) constant itself would look up the wrong variable or
# raise ``TypeError`` for ``None``.
api_key = PINECONE_API_KEY or os.getenv("PINECONE_API_KEY")
# Find your environment next to the API key in the Pinecone console.
env = PINECONE_ENVIRONMENT or os.getenv("PINECONE_ENVIRONMENT")

fashion_processor = SearchItem(api_key, env)

# Number of matches requested from the index for every search.
_TOP_K = 10
# (width, height) an uploaded image is resized to before it is encoded.
_QUERY_IMAGE_SIZE = (60, 80)


def _search_images(dense, sparse):
    """Query the index with a dense/sparse vector pair and return the images.

    Args:
        dense: Dense query vector (a list, as produced by ``.tolist()``).
        sparse: Sparse query vector returned by ``bm25.encode_queries``.

    Returns:
        A list with one entry of ``fashion_processor.images`` per match,
        looked up by the integer value of each match's ``"id"``.
    """
    # ``hybrid_scale`` is defined in ``search``; presumably it weights the
    # dense and sparse parts against each other -- confirm there.
    hdense, hsparse = fashion_processor.hybrid_scale(dense, sparse)
    result = fashion_processor.index.query(
        top_k=_TOP_K,
        vector=hdense,
        sparse_vector=hsparse,
        include_metadata=True,
    )
    return [fashion_processor.images[int(match["id"])] for match in result["matches"]]


def retrieve_images(query, image=None):
    """Dispatch a search to the image-based or the text-only retriever.

    Args:
        query: Text entered by the user.
        image: Optional uploaded file object; when truthy the image-based
            search is used, otherwise the text-only search.

    Returns:
        Whatever the selected retriever returns (a list of images, or
        ``None`` when the image-based search failed).
    """
    if image:
        return retrieve_image_from_image(image, query)
    return retrieve_image_from_query(query)


def retrieve_image_from_query(query):
    """Search using only text: BM25 sparse vector plus CLIP dense vector.

    Args:
        query: Text to encode with both ``bm25`` and ``clip_model``.

    Returns:
        A list of matching entries from ``fashion_processor.images``.
    """
    sparse = fashion_processor.bm25.encode_queries(query)
    dense = fashion_processor.clip_model.encode(query).tolist()
    return _search_images(dense, sparse)


def retrieve_image_from_image(image, query):
    """Search using an uploaded image (dense) plus accompanying text (sparse).

    Args:
        image: Uploaded file object exposing a ``name`` attribute that holds
            the path of the image on disk.
        query: Text used to build the sparse vector.

    Returns:
        A list of matching entries from ``fashion_processor.images``, or
        ``None`` if anything went wrong (the error is printed).
    """
    try:
        sparse = fashion_processor.bm25.encode_queries(query)
        # Open inside ``with`` so the file handle is released; ``resize``
        # returns an independent image that stays valid after closing.
        with Image.open(image.name) as uploaded:
            resized = uploaded.resize(_QUERY_IMAGE_SIZE)
        dense = fashion_processor.clip_model.encode(resized).tolist()
        return _search_images(dense, sparse)
    except Exception as e:
        # Deliberately best-effort: a failed search must not crash the UI.
        print(e)
        return None


def show_img(image):
    """Return the path of the uploaded image for the preview component.

    Args:
        image: Uploaded file object with a ``name`` attribute, or ``None``.

    Returns:
        ``image.name`` when an image was uploaded, otherwise ``None``. The
        result feeds a ``gr.Image`` output, which treats any string as a file
        path, so a placeholder message cannot be returned here.
    """
    return image.name if image else None


# NOTE(review): the original file was collapsed onto two lines; the nesting of
# the layout below is reconstructed from statement order -- confirm against
# the intended UI.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Shopping Search Engine
        Look for the ideal clothing items 😎
        """)

    with gr.Row():
        with gr.Column():
            query = gr.Textbox(placeholder="Search Items")
            gr.HTML("OR")
            photo = gr.Image()
            with gr.Row():
                button = gr.UploadButton(label="Upload Image", file_types=["image"])
                textbox = gr.Textbox(placeholder="Additional Details ?")
            # The label of a Button is its ``value``; ``text`` is not a
            # Button parameter.
            submit_button = gr.Button(value="Submit")

        with gr.Column():
            gallery = gr.Gallery().style(
                object_fit='contain',
                height='auto',
                preview=True
            )

    # Pressing Enter in the search box runs a text-only search.
    query.submit(fn=lambda query: retrieve_images(query), inputs=[query], outputs=[gallery])
    # Submit first previews the uploaded image, then runs the image search.
    submit_button.click(fn=lambda image, query: show_img(image), inputs=[button, textbox], outputs=[photo]) \
        .then(fn=lambda image, query: retrieve_images(query, image), inputs=[button, textbox], outputs=[gallery])

# To listen on all interfaces on a fixed port, launch with
# server_name="0.0.0.0" and server_port=8000 instead.
demo.launch()