|
import pinecone |
|
import streamlit as st |
|
from datasets import load_dataset |
|
import requests |
|
from transformers import BertTokenizerFast |
|
from sentence_transformers import SentenceTransformer |
|
import transformers.models.clip.image_processing_clip |
|
import torch |
|
import gradio as gr |
|
from deep_translator import GoogleTranslator, single_detection |
|
import shutil |
|
from PIL import Image |
|
import os |
|
|
|
# Pinecone API key, taken from the PINECONE_KEY environment variable
# (None when the variable is unset).
pkey = os.getenv("PINECONE_KEY")
|
|
|
|
|
|
|
|
|
# Fetch the helper module that provides the BM25 sparse encoder, save it next
# to the working directory, and import it.
#
# NOTE(review): this downloads Python source from a remote bucket and executes
# it via ``import`` on every start-up, with no integrity check. Anyone able to
# modify that object can run code in this process; consider vendoring the
# module or verifying a pinned checksum before importing it.
_PINECONE_TEXT_URL = (
    "https://storage.googleapis.com/gareth-pinecone-datasets/pinecone_text.py"
)

# Download BEFORE opening the file: opening with 'w' truncates it, so the old
# order destroyed any previously good copy when the request failed.
_response = requests.get(_PINECONE_TEXT_URL, timeout=30)
# Fail loudly on HTTP errors instead of writing an error page to disk and then
# failing later with a confusing SyntaxError at import time.
_response.raise_for_status()

with open('pinecone_text.py', 'w', encoding='utf-8') as fb:
    fb.write(_response.text)

import pinecone_text
|
|
|
|
|
# Authenticate against Pinecone, then open the index that text_to_image
# queries with both a dense and a sparse vector (GRPCIndex client).
pinecone.init(api_key=pkey, environment="asia-southeast1-gcp-free")

index_name = "hybrid-image-search"
index = pinecone.GRPCIndex(index_name)
|
|
|
|
|
# First 10,000 rows of the fashion product dataset. text_to_image looks up
# ``images[int(match_id)]``, so Pinecone ids are assumed to be row positions
# in this slice — keep the split in sync with how the index was built.
_FASHION_DATASET = "ashraq/fashion-product-images-small"
fashion = load_dataset(_FASHION_DATASET, split='train[:10000]')

# Split the dataset into the image column and everything else (text metadata).
images = fashion["image"]
metadata = fashion.remove_columns("image")
|
|
|
|
|
# BERT (uncased) tokenizer used by tokenize_func to produce BM25 terms.
_TOKENIZER_NAME = 'bert-base-uncased'
tokenizer = BertTokenizerFast.from_pretrained(_TOKENIZER_NAME)
|
|
|
def tokenize_func(text):
    """Return the BERT sub-word token strings for *text*.

    Used as the tokenizer callback for the BM25 encoder. Special tokens are
    not added (``add_special_tokens=False``), so only content tokens are
    counted as terms.
    """
    encoded = tokenizer(text, add_special_tokens=False)
    return tokenizer.convert_ids_to_tokens(encoded['input_ids'])
|
|
|
# Sparse encoder: BM25 statistics fitted on the catalogue's product names,
# tokenized with tokenize_func.
_product_names = metadata['productDisplayName']
bm25 = pinecone_text.BM25(tokenize_func)
bm25.fit(_product_names)
|
|
|
|
|
# Run the dense encoder on the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Dense encoder: CLIP ViT-B/32 loaded through sentence-transformers; its text
# embeddings are the dense half of the hybrid query.
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
|
|
|
def hybrid_scale(dense, sparse, alpha: float):
    """Weight a dense and a sparse query vector for hybrid search.

    Dense values are multiplied by ``alpha`` and sparse values by
    ``1 - alpha``. So ``alpha=1`` keeps only the dense signal and ``alpha=0``
    keeps only the sparse one.

    Args:
        dense: sequence of floats.
        sparse: mapping with parallel ``'indices'`` and ``'values'`` lists.
        alpha: blend factor in the closed interval [0, 1].

    Returns:
        ``(hdense, hsparse)``: a new list of scaled dense values, and a new
        dict that reuses the input ``'indices'`` list and holds newly scaled
        ``'values'``. The inputs themselves are not modified.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1] or is NaN.
    """
    # The chained comparison is False for NaN, so NaN is rejected as well.
    # The previous ``alpha < 0 or alpha > 1`` test let NaN through and every
    # scaled score became NaN.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")

    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
|
|
|
|
|
def text_to_image(query, alpha, k_results):
    """Run a hybrid (BM25 sparse + CLIP dense) search for *query*.

    Args:
        query: search text; callers in this file translate it to English first.
        alpha: dense/sparse blend in [0, 1], forwarded to hybrid_scale.
        k_results: number of matches requested from the index (``top_k``).

    Returns:
        ``(images, descriptions)``: the matched dataset images and their
        ``productDisplayName`` metadata, both in the order the index returned.
    """
    sparse_vec = bm25.transform_query(query)
    dense_vec = model.encode(query).tolist()

    scaled_dense, scaled_sparse = hybrid_scale(dense_vec, sparse_vec, alpha=alpha)

    response = index.query(
        top_k=k_results,
        vector=scaled_dense,
        sparse_vector=scaled_sparse,
        include_metadata=True,
    )
    matches = response["matches"]

    # Match ids are dataset row positions (see the dataset loading above).
    found_images = [images[int(match["id"])] for match in matches]
    descriptions = [match["metadata"]['productDisplayName'] for match in matches]
    return found_images, descriptions
|
|
|
|
|
def img_to_file_list(imgs):
    """Save *imgs* as PNGs in a fresh per-search directory and return their paths.

    Files are written to ``./searches/search_<n>/img_<i>.png``, where ``n`` is
    the current ``counter["dir_num"]``. An existing directory with that name
    is deleted first. Afterwards ``counter["dir_num"]`` is incremented, so the
    next call writes to a new directory.

    Args:
        imgs: iterable of objects exposing a PIL-style ``save(path, format)``.

    Returns:
        list[str]: the written file paths, in the same order as *imgs*.
    """
    root = './searches'
    search_num = counter["dir_num"]
    sub_path = f"{root}/search_{search_num}"

    if not os.path.exists(root):
        print("Directory does not exists")
        os.makedirs(root, exist_ok=True)
        print("The new directory is created!")

    # The counter restarts at 1 with each process, so directory names get
    # reused. Clear any stale images from an earlier search first.
    if os.path.exists(sub_path):
        shutil.rmtree(sub_path)
    os.makedirs(sub_path, exist_ok=True)

    paths = []
    for i, img in enumerate(imgs):
        file_path = f"{sub_path}/img_{i}.png"
        img.save(file_path, "PNG")
        paths.append(file_path)

    counter["dir_num"] += 1
    return paths
|
|
|
# Per-process search counter. img_to_file_list uses it to name the
# ./searches/search_<n> directories; it starts at 1.
counter = dict(dir_num=1)

# NOTE(review): this global is never read in this file (the functions that
# build file lists use their own locals). Kept in case something else imports it.
img_files = dict(x=[])

# Number of matches requested per search (passed as top_k).
K = 5
|
|
|
# NOTE(review): this language-detection API key was committed in source. It
# should be rotated and supplied only through the environment. The literal is
# kept as a fallback so existing deployments keep working unchanged.
_DETECT_API_KEY = os.environ.get(
    "DETECT_LANGUAGE_API_KEY", 'd259a6dab3bb73b1d1c2bcc6fb62b9f4'
)


def _to_english(text):
    """Return *text* translated to English, or unchanged if already English.

    Hebrew ('iw') is translated with an explicit source language, as before.
    Any other detected code falls back to Google's auto-detection instead of
    being left unhandled.
    """
    detected_language = single_detection(text, api_key=_DETECT_API_KEY)
    if detected_language == 'en':
        return text
    source = 'iw' if detected_language == 'iw' else 'auto'
    return GoogleTranslator(source=source, target='en').translate(text)


def fake_gan(text, alpha):
    """Search for *text* and return file paths of the matching images.

    Non-English input is translated first. The old code raised
    UnboundLocalError for any language other than 'iw' or 'en'.

    Args:
        text: user query, possibly non-English.
        alpha: dense/sparse blend in [0, 1].

    Returns:
        list[str]: image paths for the gallery; empty for blank input.
    """
    if not text or not text.strip():
        # Nothing to search for; skip the detection and translation calls.
        return []
    imgs, _ = text_to_image(_to_english(text), alpha, K)
    return img_to_file_list(imgs)


def fake_text(text, alpha):
    """Return product descriptions for the same search fake_gan runs.

    Uses the same translation step as fake_gan, so the descriptions line up
    with the gallery images. The old code always translated from Hebrew,
    even when the input was English.
    """
    _, descr = text_to_image(_to_english(text), alpha, K)
    return descr
|
|
|
|
|
with gr.Blocks(width=300) as demo:
    # Query inputs: free-text product description and the dense/sparse mix.
    with gr.Row():
        text = gr.Textbox(
            value="blue jeans for men",
            label="Enter the product characteristics:",
        )
        alpha = gr.Slider(0, 1, step=0.01, label='Choose alpha:', value=0.05)

    with gr.Row():
        btn = gr.Button("Generate image")

    with gr.Row():
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[8],
            rows=[1],
            object_fit='scale-down',
            height=160,
        )

    with gr.Row():
        selected = gr.Textbox(
            label="Product description: ",
            interactive=False,
            value=" The product description will appear here ",
            placeholder="Selected",
        )

    # Pressing Enter in the textbox and clicking the button run the same
    # search and fill the gallery (registered in that order).
    for trigger in (text.submit, btn.click):
        trigger(fake_gan, inputs=[text, alpha], outputs=gallery)

    def get_select_index(evt: gr.SelectData, text, alpha):
        """Return the description of the gallery item the user clicked.

        NOTE(review): this re-runs the full search through fake_text rather
        than reusing the results behind the current gallery.
        """
        print(evt.index)
        return fake_text(text, alpha)[evt.index]

    gallery.select(fn=get_select_index, inputs=[text, alpha], outputs=selected)

demo.launch(inline=False, width=700)