Spaces:
Runtime error
Runtime error
File size: 4,237 Bytes
ac1c6ae 25ab7f8 ac1c6ae 230d2d3 ac1c6ae 0cbab2a ac1c6ae 0ab69ad ac1c6ae a329e3c cfa8379 25ab7f8 6d2b087 2112a66 25ab7f8 cfa8379 26c2408 ac1c6ae 25ab7f8 d90092b ac1c6ae cfa8379 a5c02ba ac1c6ae 26c2408 ac1c6ae 25ab7f8 ac1c6ae 25ab7f8 ac1c6ae 0ab69ad 230d2d3 0ab69ad 26c2408 0ab69ad 26c2408 0ab69ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from colordescriptor import ColorDescriptor
from CLIP import CLIPImageEncoder
from LBP import LBPImageEncoder
import gradio as gr
import os
import cv2
import numpy as np
from datasets import *
# Load the "huggan/CelebA-faces" dataset via `datasets.load_dataset`.
dataset = load_dataset("huggan/CelebA-faces")
# Only the first 1000 rows of the "train" split are used as retrieval
# candidates, which keeps indexing time manageable.
candidate_subset = dataset["train"].select(range(1000)) # This is a small CBIR app! :D
def index_dataset(dataset):
    """Compute color, CLIP and LBP embeddings for ``dataset`` and index them.

    One embedding column is added per descriptor, a FAISS index is built on
    each of those columns, and every index is written to disk in the current
    working directory so that a later start-up can restore it instead of
    recomputing the embeddings.

    Args:
        dataset: A dataset object exposing ``map``; it is assumed to have an
            ``'image'`` column.

    Returns:
        The mapped dataset with the ``'color_embeddings'``,
        ``'clip_embeddings'`` and ``'lbp_embeddings'`` FAISS indexes attached.
    """
    # This function might need to be split up, to reduce start-up time of app
    # It could also use batches to increase speed
    # If indexes are saved in files, this is all not really necessary

    ## Color Embeddings
    cd = ColorDescriptor((8, 12, 3))
    dataset_with_embeddings = dataset.map(
        lambda row: {'color_embeddings': cd.describe(row["image"])}
    )  # we assume that dataset has a column 'image'

    ## CLIP Embeddings
    # NOTE(review): encode_images is presumably what produces the
    # 'clip_embeddings' column indexed below -- confirm in CLIP.py.
    clip_model = CLIPImageEncoder()
    dataset_with_embeddings = dataset_with_embeddings.map(
        clip_model.encode_images, batched=True, batch_size=16
    )

    ## LBP Embeddings
    lbp_model = LBPImageEncoder(8, 2)
    dataset_with_embeddings = dataset_with_embeddings.map(
        lambda row: {'lbp_embeddings': lbp_model.preprocess_img(row["image"])}
    )

    # Build AND persist an index for every embedding column. The LBP index
    # used to be built but never saved, so it could not be restored later.
    index_files = (
        ('color_embeddings', 'color_index.faiss'),
        ('clip_embeddings', 'clip_index.faiss'),
        ('lbp_embeddings', 'lbp_index.faiss'),
    )
    for column, index_file in index_files:
        dataset_with_embeddings.add_faiss_index(column=column)
        dataset_with_embeddings.save_faiss_index(column, index_file)

    print(dataset_with_embeddings)
    return dataset_with_embeddings
def check_index(ds):
    """Return ``ds`` with its FAISS indexes attached, building them if needed.

    If an index file exists on disk for every descriptor, the indexes are
    loaded into ``ds`` and ``ds`` itself is returned. If any file is missing,
    the whole dataset is re-indexed through ``index_dataset``.

    Args:
        ds: The candidate dataset; must expose ``load_faiss_index``.

    Returns:
        A dataset that can be queried on the ``'color_embeddings'``,
        ``'clip_embeddings'`` and ``'lbp_embeddings'`` indexes.
    """
    index_files = {
        'color_embeddings': 'color_index.faiss',
        'clip_embeddings': 'clip_index.faiss',
        'lbp_embeddings': 'lbp_index.faiss',
    }
    if all(os.path.isfile(path) for path in index_files.values()):
        for index_name, path in index_files.items():
            ds.load_faiss_index(index_name, path)
        # load_faiss_index attaches the index in place and returns None, so
        # the dataset itself must be returned, not the call's result.
        return ds
    return index_dataset(ds)
# Runs at import time: check_index either loads the saved FAISS index files
# or, when they are missing, indexes candidate_subset from scratch.
# get_neighbors reads this module-level name for every query.
dataset_with_embeddings = check_index(candidate_subset)
# Main function, to find similar images
# TODO: allow different descriptor/embedding functions
# TODO: implement different distance measures
def get_neighbors(query_image, selected_descriptor, top_k=5):
    """Returns the top k nearest examples to the query image.

    Args:
        query_image: A PIL object representing the query image, or None when
            no image was provided.
        selected_descriptor: A collection of the descriptor names chosen in
            the UI. When several are selected, the first match in the order
            "Color Descriptor", "CLIP", "LBP" is used.
        top_k: An integer representing the number of nearest examples to return.

    Returns:
        A list of the top_k most similar images as PIL objects, or an empty
        list when there is no query image or no supported descriptor.
    """
    # Without an uploaded image there is nothing to embed; return an empty
    # gallery instead of crashing inside the descriptor.
    if query_image is None:
        print("No query image was provided.")
        return []

    # Treat a missing selection like an empty one.
    chosen = selected_descriptor or []

    if "Color Descriptor" in chosen:
        index_name = 'color_embeddings'
        cd = ColorDescriptor((8, 12, 3))
        qi_embedding = np.array(cd.describe(query_image))
    elif "CLIP" in chosen:
        index_name = 'clip_embeddings'
        clip_model = CLIPImageEncoder()
        qi_embedding = clip_model.encode_image(query_image)
    elif "LBP" in chosen:
        index_name = 'lbp_embeddings'
        lbp_model = LBPImageEncoder(8, 2)
        qi_embedding = lbp_model.preprocess_img(query_image)
    else:
        print("This descriptor is not yet supported :(")
        return []

    # retrieved_examples is a dict, with images and embeddings
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        index_name, qi_embedding, k=top_k)
    return retrieved_examples['image']
def _reindex_dataset():
    """Rebuild all embeddings and indexes for the candidate subset.

    The "Re-index Dataset" button is registered without inputs, so Gradio
    calls its handler with no arguments. This wrapper supplies the dataset
    and stores the freshly indexed result in the module-level name that
    ``get_neighbors`` queries.
    """
    global dataset_with_embeddings
    dataset_with_embeddings = index_dataset(candidate_subset)


# Define the Gradio Interface
with gr.Blocks() as demo:
    image_input = gr.Image(type="pil", label="Please upload an image")
    checkboxes_descr = gr.CheckboxGroup(["Color Descriptor", "LBP", "CLIP"], label="Please choose an descriptor")
    btn = gr.Button(value="Submit")
    gallery_output = gr.Gallery()
    # top_k is not wired to an input, so get_neighbors uses its default.
    btn.click(get_neighbors, inputs=[image_input, checkboxes_descr], outputs=[gallery_output])
    btn_index = gr.Button(value="Re-index Dataset")
    # index_dataset needs a dataset argument, so it cannot be registered
    # directly on an input-less click; go through the wrapper instead.
    btn_index.click(_reindex_dataset)

if __name__ == "__main__":
    demo.launch()
|