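"""Gradio demo of a small content-based image retrieval (CBIR) app.

A subset of huggan/CelebA-faces is indexed with LBP, color-histogram, and CLIP
embeddings, and the most similar images are retrieved with FAISS.
"""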
from colordescriptor import ColorDescriptor
from CLIP import CLIPImageEncoder
from LBP import LBPImageEncoder
import gradio as gr
import os
import numpy as np
from datasets import load_dataset


dataset = load_dataset("huggan/CelebA-faces", download_mode='force_redownload')
dataset.cleanup_cache_files()
candidate_subset = dataset["train"].select(range(10)) # This is a small CBIR app! :D

def index_dataset(dataset):
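    """Computes the LBP, color, and CLIP embeddings for every row and builds the FAISS indices."""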
    print(dataset)

    print("LBP Embeddings")
    lbp_model = LBPImageEncoder(8, 2)
    dataset_with_embeddings = dataset.map(lambda row: {'lbp_embeddings': lbp_model.describe(row["image"])})

    print("Color Embeddings")
    cd = ColorDescriptor((8, 12, 3))
    dataset_with_embeddings = dataset_with_embeddings.map(lambda row: {'color_embeddings': cd.describe(row["image"])})

    print("CLIP Embeddings")
    clip_model = CLIPImageEncoder()
    dataset_with_embeddings = dataset_with_embeddings.map(clip_model.encode_images, batched=True, batch_size=16)


    # Build the FAISS indices and persist them so check_index can reload them later
    dataset_with_embeddings.add_faiss_index(column='color_embeddings')
    dataset_with_embeddings.save_faiss_index('color_embeddings', 'color_index.faiss')
    dataset_with_embeddings.add_faiss_index(column='clip_embeddings')
    dataset_with_embeddings.save_faiss_index('clip_embeddings', 'clip_index.faiss')
    dataset_with_embeddings.add_faiss_index(column='lbp_embeddings')
    dataset_with_embeddings.save_faiss_index('lbp_embeddings', 'lbp_index.faiss')


    print(dataset_with_embeddings)
    return dataset_with_embeddings


def check_index(ds):
    """Loads the saved FAISS indices if they all exist on disk; otherwise re-indexes the dataset."""
    if (os.path.isfile('color_index.faiss') and os.path.isfile('clip_index.faiss')
            and os.path.isfile('lbp_index.faiss')):
        ds.load_faiss_index('color_embeddings', 'color_index.faiss')
        ds.load_faiss_index('clip_embeddings', 'clip_index.faiss')
        ds.load_faiss_index('lbp_embeddings', 'lbp_index.faiss')
        return ds  # load_faiss_index works in place and returns None, so return ds itself
    else:
        return index_dataset(ds)


dataset_with_embeddings = check_index(candidate_subset)

# Main function: finds images similar to the query image
# TODO: implement different distance measures (see the sketch after get_neighbors below)

def get_neighbors(query_image, selected_descriptor, top_k=5):
    """Returns the top k nearest examples to the query image.

    Args:
        query_image: A PIL object representing the query image.
        top_k: An integer representing the number of nearest examples to return.

    Returns:
        A list of the top_k most similar images as PIL objects.
    """
    if  "Color Descriptor" in selected_descriptor:
        cd = ColorDescriptor((8, 12, 3))
        qi_embedding = cd.describe(query_image)
        qi_np = np.array(qi_embedding)
        scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
            'color_embeddings', qi_np, k=top_k)
        images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
        return images
    if "CLIP" in selected_descriptor:
        clip_model = CLIPImageEncoder()
        qi_embedding = clip_model.encode_image(query_image)
        scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
            'clip_embeddings', qi_embedding, k=top_k)
        images = retrieved_examples['image']
        return images
    if "LBP" in selected_descriptor:
        lbp_model = LBPImageEncoder(8,2)
        qi_embedding = lbp_model.describe(query_image)
        scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
            'lbp_embeddings', qi_embedding, k=top_k)
        images = retrieved_examples['image']
        return images
    else:
        print("This descriptor is not yet supported :(")
        return []
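

# A minimal sketch toward the distance-measure TODO above, assuming the same
# `datasets` FAISS integration used in index_dataset(): add_faiss_index accepts
# a faiss metric_type, so an inner-product index (equivalent to cosine
# similarity when the embeddings are L2-normalised) could be built alongside
# the default L2 indices. The helper name and the 'clip_embeddings_ip' index
# name are illustrative, not part of the app above.
def add_inner_product_index(ds):
    import faiss  # already installed: datasets' FAISS indices depend on it
    ds.add_faiss_index(column='clip_embeddings',
                       index_name='clip_embeddings_ip',
                       metric_type=faiss.METRIC_INNER_PRODUCT)
    return ds
# Querying would then pass the index name instead of the column name, e.g.:
#   scores, examples = ds.get_nearest_examples('clip_embeddings_ip', query_vec, k=5)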


# Define the Gradio Interface

with gr.Blocks() as demo:
    image_input = gr.Image(type="pil", label="Please upload an image")
    checkboxes_descr = gr.CheckboxGroup(["Color Descriptor", "LBP", "CLIP"], label="Please choose a descriptor")

    btn = gr.Button(value="Submit")
    gallery_output = gr.Gallery()
    btn.click(get_neighbors, inputs=[image_input, checkboxes_descr], outputs=[gallery_output])

    btn_index = gr.Button(value="Re-index Dataset")
    btn_index.click(lambda: index_dataset(candidate_subset))  # click passes no inputs, so bind the dataset here

if __name__ == "__main__":
    demo.launch()