File size: 5,409 Bytes
28379d5
0261bdd
28379d5
 
 
 
 
 
 
bc13c1f
28379d5
 
 
 
32fabcd
df5d192
 
 
 
28379d5
 
 
 
 
 
df5d192
28379d5
 
 
 
 
 
 
 
 
16dfef3
28379d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc13c1f
28379d5
 
 
 
 
 
 
 
bc13c1f
28379d5
 
 
 
 
 
 
bc13c1f
28379d5
 
 
 
 
 
 
 
 
bc13c1f
 
28379d5
bc13c1f
28379d5
bc13c1f
c70c834
40948e5
bc13c1f
 
 
40948e5
bc13c1f
 
40948e5
 
 
 
 
 
 
bc13c1f
40948e5
 
 
bc13c1f
40948e5
 
 
bc13c1f
 
 
 
40948e5
bc13c1f
40948e5
 
 
 
 
 
bc13c1f
40948e5
 
af693b0
40948e5
bc13c1f
 
 
40948e5
 
 
 
 
 
 
bc13c1f
40948e5
 
2dcffa6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import pinecone
import streamlit as st
from datasets import load_dataset
import requests
from transformers import BertTokenizerFast
from sentence_transformers import SentenceTransformer
import transformers.models.clip.image_processing_clip
import torch
import gradio as gr
from deep_translator import GoogleTranslator, single_detection
import shutil
from PIL import Image
import os

pkey = st.secrets["PINECONE_KEY"]

# Fetch the BM25 helper module from the public example bucket and import it
# from the working directory.
# NOTE(review): this downloads and then imports (i.e. executes) remote Python
# code at start-up; vendor the file or verify a checksum if the bucket is not
# fully trusted.
_PINECONE_TEXT_URL = 'https://storage.googleapis.com/gareth-pinecone-datasets/pinecone_text.py'
# Bounded timeout so a stalled download cannot hang app start-up forever.
_pinecone_text_resp = requests.get(_PINECONE_TEXT_URL, timeout=60)
# Fail loudly on an HTTP error instead of writing the error page into
# pinecone_text.py, where it would only surface later as a confusing import error.
_pinecone_text_resp.raise_for_status()
with open('pinecone_text.py', 'w', encoding='utf-8') as fb:
    fb.write(_pinecone_text_resp.text)
import pinecone_text

# init connection to pinecone
pinecone.init(
    api_key=pkey,  # app.pinecone.io
    environment="asia-southeast1-gcp-free"  # find next to api key
)

index_name = "hybrid-image-search"
index = pinecone.GRPCIndex(index_name)

# load the first 10,000 training rows of the dataset from the huggingface hub
fashion = load_dataset(
    "ashraq/fashion-product-images-small",
    split='train[:10000]'
)

# Keep the image column on its own (looked up by match id at query time) and
# every other column as metadata.
images = fashion["image"]
metadata = fashion.remove_columns("image")

# load bert tokenizer from huggingface
tokenizer = BertTokenizerFast.from_pretrained(
    'bert-base-uncased'
)

def tokenize_func(text):
    """Split *text* into token strings using the module-level BERT tokenizer.

    Special tokens such as [CLS]/[SEP] are not added, so only the tokens of
    the text itself are returned. Used as the tokenizer for the BM25 model.
    """
    encoding = tokenizer(text, add_special_tokens=False)
    return tokenizer.convert_ids_to_tokens(encoding['input_ids'])

# Sparse side: BM25 over the BERT tokens, fitted on the product display names.
bm25 = pinecone_text.BM25(tokenize_func)
bm25.fit(metadata['productDisplayName'])

# Dense side: run the CLIP encoder on the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# load a CLIP model from huggingface
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)

def hybrid_scale(dense, sparse, alpha: float):
    """Weight a dense and a sparse query vector for hybrid search.

    The dense values are multiplied by ``alpha`` and the sparse values by
    ``1 - alpha``, so ``alpha=1`` keeps only the dense signal and ``alpha=0``
    keeps only the sparse one.

    Args:
        dense: Sequence of floats.
        sparse: Mapping with ``'indices'`` and ``'values'`` sequences.
        alpha: Weight in the closed interval [0, 1].

    Returns:
        ``(hdense, hsparse)``: a new scaled dense list, and a new sparse dict
        whose ``'indices'`` is the very same object as ``sparse['indices']``.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1] or is NaN.
    """
    # Chained comparison so NaN (for which every comparison is False) is
    # rejected too, instead of silently turning every weight into NaN.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse


def text_to_image(query, alpha, k_results):
    """Run a hybrid BM25 + CLIP search for *query* against the Pinecone index.

    Args:
        query: Search text.
        alpha: Dense/sparse weight forwarded to ``hybrid_scale``.
        k_results: Number of matches to request (``top_k``).

    Returns:
        ``(imgs, description)``: the dataset images for the matches and their
        ``productDisplayName`` metadata, both in match order.
    """
    sparse_vec = bm25.transform_query(query)
    dense_vec = model.encode(query).tolist()
    scaled_dense, scaled_sparse = hybrid_scale(dense_vec, sparse_vec, alpha=alpha)

    matches = index.query(
        top_k=k_results,
        vector=scaled_dense,
        sparse_vector=scaled_sparse,
        include_metadata=True,
    )["matches"]

    # Match ids are used as integer positions into the ``images`` column
    # (presumably assigned as row numbers when the index was built — confirm).
    imgs = [images[int(match["id"])] for match in matches]
    description = [match["metadata"]['productDisplayName'] for match in matches]
    return imgs, description


def img_to_file_list(imgs):
    """Save *imgs* as PNG files in a per-search directory and return their paths.

    Files are written to ``./searches/search_<n>/img_<i>.png``, where ``<n>`` is
    the current ``counter["dir_num"]``. Any existing directory with that name
    is deleted first. ``counter["dir_num"]`` is incremented once all images are
    saved.

    Args:
        imgs: Iterable of image objects exposing ``save(path, format)``
            (presumably PIL images from the dataset — confirm).

    Returns:
        List of the written file paths, in input order.
    """
    # Read the global counter exactly once so the directory name and every
    # returned path agree even if the counter is bumped concurrently.
    dir_num = counter["dir_num"]
    sub_path = f'./searches/search_{dir_num}'

    # Start from an empty directory; makedirs also creates ./searches if needed.
    if os.path.exists(sub_path):
        shutil.rmtree(sub_path)
    os.makedirs(sub_path, exist_ok=True)

    file_paths = []
    for i, img in enumerate(imgs):
        file_path = f'{sub_path}/img_{i}.png'
        img.save(file_path, "PNG")
        file_paths.append(file_path)

    # NOTE(review): one new directory is kept per search for the lifetime of
    # the process; consider pruning old ones.
    counter["dir_num"] += 1
    return file_paths

# Number of the next search directory; a dict so img_to_file_list can bump it
# without a ``global`` statement.
counter = dict(dir_num=1)
# NOTE(review): not read anywhere in this file (fake_gan assigns a local of the
# same name); kept in case something imports it.
img_files = dict(x=[])

# Number of results fetched per search.
K = 5

def fake_gan(text, alpha):
    """Search for products matching *text* and return saved image file paths.

    The query language is detected first. English is searched as-is, Hebrew
    (``'iw'``) is translated to English, and any other detected language is
    translated to English with automatic source detection. The top ``K``
    images are then written to disk by ``img_to_file_list``.

    Args:
        text: User query.
        alpha: Dense/sparse weight in [0, 1], forwarded to the search.

    Returns:
        List of image file paths for the gallery.
    """
    # NOTE(review): API key is hard-coded in source; move it to st.secrets
    # (like PINECONE_KEY) and rotate it.
    detected_language = single_detection(text, api_key='d259a6dab3bb73b1d1c2bcc6fb62b9f4')

    if detected_language == 'en':
        query = text
    elif detected_language == 'iw':
        query = GoogleTranslator(source='iw', target='en').translate(text)
    else:
        # Without this branch any other language left the search results
        # unbound and the call crashed with UnboundLocalError.
        query = GoogleTranslator(source='auto', target='en').translate(text)

    imgs, _ = text_to_image(query, alpha, K)
    return img_to_file_list(imgs)

def fake_text(text, alpha):
    """Return the product descriptions for the top ``K`` hits of *text*.

    The text is translated to English with automatic source-language
    detection, so both English and Hebrew queries work. A hard-coded Hebrew
    source treated every input as Hebrew, unlike ``fake_gan``, which detects
    the language first.

    NOTE(review): this re-runs the whole search. If the text or alpha changed
    since the gallery was rendered, the descriptions may not line up with the
    displayed images.

    Returns:
        List of ``productDisplayName`` strings in match order.
    """
    en_text = GoogleTranslator(source='auto', target='en').translate(text)
    _, descr = text_to_image(en_text, alpha, K)
    return descr


with gr.Blocks(width=300) as demo:

    with gr.Row():
        text = gr.Textbox(
            value="blue jeans for men",
            label="Enter the product characteristics:",
        )
        alpha = gr.Slider(0, 1, step=0.01, label='Choose alpha:', value=0.05)

    with gr.Row():
        btn = gr.Button("Generate image")

    with gr.Row():
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[8],
            rows=[1],
            object_fit='scale-down',
            height=160,
        )

    with gr.Row():
        selected = gr.Textbox(
            label="Product description: ",
            interactive=False,
            value="   The product description will appear here ",
            placeholder="Selected",
        )

    def get_select_index(evt: gr.SelectData, query_text, alpha_value):
        """Return the description of the gallery image the user clicked."""
        print(evt.index)
        return fake_text(query_text, alpha_value)[evt.index]

    # Search on Enter in the textbox as well as on button press.
    text.submit(fake_gan, inputs=[text, alpha], outputs=gallery)
    btn.click(fake_gan, inputs=[text, alpha], outputs=gallery)

    gallery.select(fn=get_select_index, inputs=[text, alpha], outputs=selected)

demo.launch(inline=False, width=700)