Spaces:
Sleeping
Sleeping
File size: 5,443 Bytes
773c328 f4b2f43 78a875d f91916c 0255c17 773c328 a3f47db 0255c17 a3f47db 773c328 53dc0c8 773c328 a8f3621 773c328 53dc0c8 f3145a5 53dc0c8 edaf3fc 5d62c98 d4ac3a0 b8dc911 cdb1c55 b8dc911 80e8aa8 0255c17 edaf3fc 53dc0c8 6b805a8 53dc0c8 12b4459 68fc4fc 12b4459 ed01448 14bc159 12b4459 14bc159 53dc0c8 f3145a5 53dc0c8 64b8613 cdb1c55 64b8613 53dc0c8 64b8613 53dc0c8 64b8613 8bdc672 53dc0c8 64b8613 53dc0c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import gradio as gr
from io import BytesIO
from base64 import b64encode
from pinecone_text.sparse import BM25Encoder
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import os
import re
####################
import pandas as pd
##########################
# CLIP text/image encoder used to produce the dense query embedding.
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32')
# ~44k fashion products with images + metadata (HF dataset).
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
# DataFrame view of the dataset, used for exact-match lookups in process_input.
fashion_df = pd.DataFrame(fashion)
# PIL images indexed by row id; Pinecone match ids map back into this list.
images = fashion['image']
metadata = fashion.remove_columns('image')
# Deduplicated product names — the searchable catalogue.
item_list = list(set(metadata['productDisplayName']))
INDEX_NAME = 'srinivas-hybrid-search'
# NOTE(review): assumes the Space has `pinecone_api_key` set in its secrets.
PINECONE_API_KEY = os.getenv('pinecone_api_key')
pinecone = Pinecone(api_key=PINECONE_API_KEY)
# Existing index — assumed already populated with ids matching `images` rows; verify.
index = pinecone.Index(INDEX_NAME)
# BM25 sparse encoder fitted on the product names (keyword signal).
bm25 = BM25Encoder()
bm25.fit(metadata['productDisplayName'])
def display_result(image_batch, match_batch):
    """Render matched product images with their titles as an HTML grid.

    Each image is normalised to RGB, PNG-encoded, base64-embedded into an
    <img> tag with its title as a <figcaption>, and the figures are laid
    out in a 4-column CSS grid. Returns the HTML as a string.
    """
    cells = []
    for picture, caption in zip(image_batch, match_batch):
        # PNG embedding wants RGB; convert palette/greyscale images first.
        rgb = picture if picture.mode == 'RGB' else picture.convert('RGB')
        buffer = BytesIO()
        rgb.save(buffer, format='PNG')
        encoded = b64encode(buffer.getvalue()).decode('utf-8')
        cells.append(f'''
        <figure style="margin: 0; padding: 0; text-align: left;">
          <figcaption style="font-weight: bold; margin:0;">{caption}</figcaption>
          <img src="data:image/png;base64,{encoded}" style="width: 180px; height: 240px; margin: 0;" >
        </figure>
        ''')
    body = ''.join(cells)
    return f'''
    <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 20px; align-items: start;">
    {body}
    </div>
    '''
def hybrid_scale(dense, sparse, alpha):
    """Weight dense and sparse query vectors by a convex combination.

    alpha=1 gives a purely dense (semantic) query, alpha=0 purely sparse
    (keyword/BM25). Returns (scaled_dense, scaled_sparse).

    Raises:
        ValueError: if alpha lies outside [0, 1].
    """
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    sparse_weight = 1 - alpha
    scaled_sparse = {
        'indices': sparse['indices'],
        'values': [sparse_weight * v for v in sparse['values']],
    }
    scaled_dense = [alpha * v for v in dense]
    return scaled_dense, scaled_sparse
def process_input(query, slider_value):
    """Run a hybrid dense+sparse Pinecone search for `query`.

    `slider_value` (0..1) sets the dense/sparse mix via hybrid_scale.
    If the query is an exact product name from the catalogue, that exact
    product is pinned to the front of the results. Returns an HTML grid
    of matched images, or a red error paragraph on any failure.
    """
    print(f"Query: {query}")
    # Lookahead pattern requiring every query word to appear somewhere in
    # the item name, in any order (case-insensitive).
    search_words = query.lower().split()
    pattern = r"(?=.*" + r")(?=.*".join(map(re.escape, search_words)) + r")"
    filtered_items = [item for item in item_list if re.search(pattern, item.lower())]
    print(f"Filtered items: {filtered_items}")
    filtered_df = fashion_df[fashion_df['productDisplayName'].isin(filtered_items)]
    try:
        alpha = float(slider_value)
        sparse = bm25.encode_queries(query)
        dense = model.encode(query).tolist()
        hdense, hsparse = hybrid_scale(dense, sparse, alpha)
        result = index.query(
            top_k=12,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True
        )
        # Pinecone ids were indexed as dataset row numbers — map back to images.
        imgs = [images[int(r["id"])] for r in result["matches"]]
        matches = [x["metadata"]['productDisplayName'] for x in result["matches"]]
        # Exact catalogue match: show the precise product first.
        if query in filtered_items:
            exact_img = filtered_df.loc[filtered_df['productDisplayName'] == query, 'image'].iat[0]
            imgs.insert(0, exact_img)
            matches.insert(0, query)
        print(f"No. of matching images: {len(imgs)}")
        print(matches)
        return display_result(imgs, matches)
    except Exception as e:
        # Log the real failure for debugging; the UI only sees a friendly message.
        print(f"Search failed: {e}")
        return "<p style='color:red;'>Not found. Try another search</p>"
def update_textbox(choice):
    """Mirror the dropdown selection into the search textbox.

    Gradio relay callback: the returned value becomes the textbox content,
    which in turn fires the textbox's change handler (the actual search).
    """
    return choice
def text_process(search_string):
    """Narrow the product catalogue to names containing all typed words.

    Returns two gr.update objects for the dropdown: one making it visible,
    one populating its choices (first match preselected, or empty string
    when nothing matches).
    """
    # Title-case each word so it matches the catalogue's naming style
    # (e.g. "blue shirt" matches "... Blue Shirt").
    # NOTE(review): items whose names are not strictly title-cased would be
    # missed by this case-sensitive search — confirm against the dataset.
    search_words = search_string.title().split()
    # Lookahead pattern: every word must appear somewhere, in any order.
    pattern = r"(?=.*" + r")(?=.*".join(map(re.escape, search_words)) + r")"
    filtered_items = [item for item in item_list if re.search(pattern, item)]
    first_choice = filtered_items[0] if filtered_items else ""
    return gr.update(visible=True), gr.update(choices=filtered_items, value=first_choice)
# --- Gradio UI wiring -------------------------------------------------------
# Flow: typing in the textbox (or moving the slider) runs the hybrid search;
# the button narrows the hidden dropdown; picking a dropdown item is copied
# back into the textbox, which re-triggers the search.
with gr.Blocks() as demo:
    gr.Markdown("# Get Fashion Items Recommended Based On Your Search..\n"
                "## Recommender System implemented with Pinecone Vector Database with Dense & Sparse Embeddings and Hybrid Search..")
    with gr.Row():
        text_input = gr.Textbox(label="Type-in what you are looking for..")
        submit_btn = gr.Button("Click this button for further filtering..")
    dropdown = gr.Dropdown(label="Click here and select to narrow your search..",
                           value="Select an item from this list or start typing",
                           allow_custom_value=True, interactive=True, visible=False)
    slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5,
                       label="Adjust the Slider to get better recommendations that suit what you are looking for..",
                       interactive=True)
    # Dropdown selection is relayed into the textbox, whose change event searches.
    dropdown.change(fn=update_textbox, inputs=dropdown, outputs=text_input)
    html_output = gr.HTML(label="Relevant Images")
    # text_process returns (visibility update, choices update) for the same dropdown.
    submit_btn.click(fn=text_process, inputs=[text_input], outputs=[dropdown, dropdown])
    text_input.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)
    slider.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)

demo.launch(debug=True, share=True)