import gradio as gr
from PIL import Image
from datasets import load_dataset, Dataset
import random
import numpy as np
import time


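# The embeddings datasets below are expected to contain 'id', 'image', and 'embeddings' columns.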
#ds = load_dataset("tonyassi/lucy4-embeddings", split='train')
ds = load_dataset("tonyassi/finesse1-embeddings", split='train')
#ds = load_dataset("tonyassi/lucy5-embeddings", split='train')
# Fast lookup from product id to its dataset row
id_to_row = {row['id']: row for row in ds}

# Per-session state: products not yet shown, and the running preference embedding
remaining_ds = None
preference_embedding = []

###################################################################################

def get_random_images(dataset, num):
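    """Pick `num` random rows from `dataset`.

    Returns the selected rows and a new dataset containing the remaining rows,
    so the same products are not shown again in a session.
    """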
    # Select `num` random indices from the dataset
    random_indices = random.sample(range(len(dataset)), num)
    
    # Get the randomly selected images
    random_images = dataset.select(random_indices)
    
    # Create a new dataset with the remaining images
    remaining_indices = [i for i in range(len(dataset)) if i not in random_indices]
    new_dataset = dataset.select(remaining_indices)
    
    return random_images, new_dataset

def find_similar_images(dataset, num, embedding):
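    """Return the `num` rows of `dataset` whose embeddings are nearest to `embedding`.

    Builds a temporary FAISS index over the 'embeddings' column, runs the
    nearest-neighbour search, then drops the index and also returns a new
    dataset with the retrieved rows removed.
    """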
    start_time = time.time()
    
    # Build a FAISS index over the embeddings and search for the most similar images
    dataset.add_faiss_index(column='embeddings')
    scores, retrieved_examples = dataset.get_nearest_examples('embeddings', np.array(embedding), k=num)
    
    print('time 2.1:', time.time()-start_time)
    
    # Drop the FAISS index now that the search is done
    dataset.drop_index('embeddings')
    print('time 2.2:', time.time()-start_time)

    # Extract all dataset IDs and use a set to find remaining indices
    dataset_ids = dataset['id']
    retrieved_ids_set = set(retrieved_examples['id'])

    # Use a list comprehension with enumerate for faster indexing
    remaining_indices = [i for i, id in enumerate(dataset_ids) if id not in retrieved_ids_set]
    
    print('time 2.3:', time.time()-start_time)

    # Create a new dataset without the retrieved images
    new_dataset = dataset.select(remaining_indices)

    print('time 2.4:', time.time()-start_time)
    return retrieved_examples, new_dataset



def average_embedding(embedding1, embedding2):
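    """Return the element-wise mean of two embedding vectors as a NumPy array."""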
    embedding1 = np.array(embedding1)
    embedding2 = np.array(embedding2)
    return (embedding1 + embedding2) / 2

###################################################################################

def load_images():
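    """Reset the session state and return 10 random products as (image, caption) pairs."""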
    print('load_images()')
    print("ds", ds.num_rows)

    global remaining_ds, preference_embedding

    # Reset the session state
    preference_embedding = []

    # Get an initial batch of random products
    rand_imgs, remaining_ds = get_random_images(ds, 10)

    # Create a list of tuples [(img1,caption1),(img2,caption2)...]
    result = list(zip(rand_imgs['image'], [str(id) for id in rand_imgs['id']]))

    return result


def select_image(evt: gr.SelectData, gallery, preference_gallery):
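    """Handle a gallery selection.

    Update the preference embedding with the selected product's embedding, then
    return a gallery of 5 similar and 5 random products plus the updated
    preference gallery.
    """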
    start_time = time.time()

    print('select_image()')

    global remaining_ds
    print("remaining_ds", remaining_ds.num_rows)
    
    # Selected image
    selected_id = int(evt.value['caption'])
    print('ID', selected_id)
    #selected_row = ds.filter(lambda row: row['id'] == selected_id)[0]
    selected_row = id_to_row[selected_id]
    selected_embedding = selected_row['embeddings']
    selected_image = selected_row['image']

    print('time 1:', time.time()-start_time)

    # Update preference embedding
    global preference_embedding
    if len(preference_embedding) == 0:
        preference_embedding = selected_embedding
    else: 
        preference_embedding = average_embedding(preference_embedding, selected_embedding)

    print('time 2:', time.time()-start_time)

    # Find images which are most similar to the preference embedding
    similar_images, remaining_ds = find_similar_images(remaining_ds, 5, preference_embedding)

    print('time 3:', time.time()-start_time)

    # Create a list of tuples [(img1,caption1),(img2,caption2)...]
    result = list(zip(similar_images['image'], [str(id) for id in similar_images['id']]))

    print('time 4:', time.time()-start_time)

    # Get random images
    rand_imgs, remaining_ds = get_random_images(remaining_ds, 5)
    # Create a list of tuples [(img1,caption1),(img2,caption2)...]
    random_result = list(zip(rand_imgs['image'], [str(id) for id in rand_imgs['id']]))

    final_result = result + random_result

    # Update preference gallery
    if preference_gallery is None:
        final_preference_gallery = [selected_image]
    else:
        final_preference_gallery = [selected_image] + preference_gallery

    print('time 5:', time.time()-start_time)

    return gr.Gallery(value=final_result, selected_index=None), final_preference_gallery

###################################################################################

with gr.Blocks() as demo:
    gr.Markdown("""
    <center><h1> Product Recommendation using Image Similarity </h1></center>

    <center>by <a href="https://www.tonyassi.com/" target="_blank">Tony Assi</a></center>


    <center> This is a demo of product recommendation using image similarity of user preferences. </center>

    The user selects their favorite product, which is then added to the user preference group. The image embeddings of the products in the preference group are averaged into a preference embedding. Each round, ten products are displayed: the 5 products most similar to the user preference embedding and 5 random products. Embeddings are generated with [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224). The dataset used is [tonyassi/finesse1-embeddings](https://huggingface.co/datasets/tonyassi/finesse1-embeddings).
    """)

    product_gallery = gr.Gallery(columns=5, object_fit='contain', allow_preview=False, label='Products')
    preference_gallery = gr.Gallery(columns=5, object_fit='contain', allow_preview=False, label='Preference', interactive=False)

    demo.load(load_images, inputs=None, outputs=[product_gallery])
    product_gallery.select(select_image, inputs=[product_gallery, preference_gallery], outputs=[product_gallery, preference_gallery])
  

demo.launch()