from colordescriptor import ColorDescriptor
from CLIP import CLIPImageEncoder
import gradio as gr
import numpy as np
from datasets import load_dataset
dataset = load_dataset("huggan/CelebA-faces")
candidate_subset = dataset["train"].select(range(40)) # This is a small CBIR app! :D
def emb_dataset(dataset):
    # This function might need to be split up to reduce the start-up time of the app.
    # The color-embedding step could also use batches to increase speed.
    # If the indexes were saved to files, none of this would be necessary at start-up
    # (see the sketch below).

    ## Color embeddings
    cd = ColorDescriptor((8, 12, 3))
    # We assume the dataset has an 'image' column
    dataset_with_embeddings = dataset.map(
        lambda row: {'color_embeddings': cd.describe(row["image"])}
    )

    ## CLIP embeddings
    clip_model = CLIPImageEncoder()
    dataset_with_embeddings = dataset_with_embeddings.map(
        clip_model.encode_images, batched=True, batch_size=8
    )

    # Add a FAISS index for each embedding column
    dataset_with_embeddings.add_faiss_index(column='color_embeddings')
    dataset_with_embeddings.add_faiss_index(column='clip_embeddings')
    print(dataset_with_embeddings)

    return dataset_with_embeddings
dataset_with_embeddings = emb_dataset(candidate_subset)
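
# One way to address the start-up cost noted in emb_dataset: persist the FAISS
# indexes to disk and reload them on later launches instead of recomputing all
# embeddings. A minimal sketch using datasets' save_faiss_index/load_faiss_index
# (the file names here are placeholders, not part of the app):
#
#   dataset_with_embeddings.save_faiss_index('color_embeddings', 'color_index.faiss')
#   dataset_with_embeddings.save_faiss_index('clip_embeddings', 'clip_index.faiss')
#   # ...then, on a later start-up, attach the saved indexes to the freshly
#   # loaded (un-embedded) subset, which must contain the same rows:
#   candidate_subset.load_faiss_index('color_embeddings', 'color_index.faiss')
#   candidate_subset.load_faiss_index('clip_embeddings', 'clip_index.faiss')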
# Main function to find similar images
# TODO: allow different descriptor/embedding functions
# TODO: implement different distance measures
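# For the second TODO: datasets' add_faiss_index accepts a metric_type argument,
# so a different distance measure could be plugged in when building the index.
# A sketch (faiss.METRIC_INNER_PRODUCT is a standard faiss constant):
#
#   import faiss
#   dataset_with_embeddings.add_faiss_index(
#       column='clip_embeddings', metric_type=faiss.METRIC_INNER_PRODUCT)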
def get_neighbors(query_image, selected_descriptor, top_k=5):
    """Returns the top_k nearest examples to the query image.

    Args:
        query_image: A PIL object representing the query image.
        selected_descriptor: A list of the selected descriptor methods.
        top_k: An integer representing the number of nearest examples to return.

    Returns:
        A list of the top_k most similar images as PIL objects.
    """
    if "Color Descriptor" in selected_descriptor:
        cd = ColorDescriptor((8, 12, 3))
        qi_embedding = cd.describe(query_image)
        qi_np = np.array(qi_embedding)
        scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
            'color_embeddings', qi_np, k=top_k)
        # retrieved_examples is a dict holding both the images and their embeddings
        images = retrieved_examples['image']
        return images
    elif "CLIP" in selected_descriptor:
        clip_model = CLIPImageEncoder()
        qi_embedding = clip_model.encode_image(query_image)
        scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
            'clip_embeddings', qi_embedding, k=top_k)
        images = retrieved_examples['image']
        return images
    else:
        print("This descriptor is not yet supported :(")
        return []
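
# Note that selected_descriptor is a list (it comes from a CheckboxGroup), so a
# direct call outside Gradio would look like this (assuming `img` is a PIL image):
#
#   neighbors = get_neighbors(img, ["CLIP"], top_k=3)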
# Define the Gradio interface
iface = gr.Interface(
    fn=get_neighbors,
    inputs=[
        gr.Image(type="pil", label="Your Image"),
        gr.CheckboxGroup(["Color Descriptor", "LBP", "CLIP"], label="Descriptor method?"),
    ],
    outputs=gr.Gallery(),
    title="Image Similarity Gallery",
    description="Upload an image and get similar images",
    allow_flagging="never",
)

# Launch the Gradio interface
iface.launch()