from colordescriptor import ColorDescriptor
from CLIP import CLIPImageEncoder
import gradio as gr
import cv2
import numpy as np
from datasets import *

# This is a small CBIR (content-based image retrieval) app! :D

dataset = load_dataset("huggan/CelebA-faces")
candidate_subset = dataset["train"].select(range(40))

# Bin configuration for the colour descriptor. Indexing and querying must use
# the same value, otherwise query vectors will not be comparable to the index.
COLOR_BINS = (8, 12, 3)


def emb_dataset(dataset, color_descriptor=None, clip_encoder=None):
    """Add colour and CLIP embedding columns, each with a FAISS index.

    NOTE: This could be split up or batched further to reduce app start-up
    time; if the indexes were saved to files, recomputing them at start-up
    would not be necessary.

    Args:
        dataset: A ``datasets.Dataset`` that has an ``image`` column.
        color_descriptor: Optional ``ColorDescriptor`` to reuse. A new one
            built with ``COLOR_BINS`` is created when omitted.
        clip_encoder: Optional ``CLIPImageEncoder`` to reuse. A new one is
            created when omitted.

    Returns:
        The dataset with ``color_embeddings`` and ``clip_embeddings`` columns,
        each with a FAISS index attached under the same name.
    """
    if color_descriptor is None:
        color_descriptor = ColorDescriptor(COLOR_BINS)
    if clip_encoder is None:
        clip_encoder = CLIPImageEncoder()

    # Colour embeddings, computed one row at a time.
    dataset_with_embeddings = dataset.map(
        lambda row: {"color_embeddings": color_descriptor.describe(row["image"])}
    )
    # CLIP embeddings, computed in batches. encode_images must produce the
    # 'clip_embeddings' column, since it is indexed below.
    dataset_with_embeddings = dataset_with_embeddings.map(
        clip_encoder.encode_images, batched=True, batch_size=8
    )

    dataset_with_embeddings.add_faiss_index(column="color_embeddings")
    dataset_with_embeddings.add_faiss_index(column="clip_embeddings")

    print(dataset_with_embeddings)
    return dataset_with_embeddings


# Shared encoders: built once and reused for indexing and for every query,
# instead of re-instantiating them (and presumably reloading the CLIP model)
# on each request.
# NOTE(review): assumes both encoders keep no per-call state — confirm.
color_descriptor = ColorDescriptor(COLOR_BINS)
clip_encoder = CLIPImageEncoder()

dataset_with_embeddings = emb_dataset(
    candidate_subset,
    color_descriptor=color_descriptor,
    clip_encoder=clip_encoder,
)


# Main function, to find similar images
# TODO: allow different descriptor/embedding functions
# TODO: implement different distance measures
def get_neighbors(query_image, selected_descriptor, top_k=5):
    """Return the top-k dataset images nearest to the query image.

    Args:
        query_image: A PIL image to search with. It is ``None`` when the user
            submits without uploading an image.
        selected_descriptor: Collection of descriptor names ticked in the UI,
            e.g. ``["Color Descriptor"]``. When several are ticked,
            "Color Descriptor" takes precedence over "CLIP". ``None`` is
            treated as no selection.
        top_k: Number of nearest examples to return.

    Returns:
        A list of up to ``top_k`` images from the dataset's ``image`` column.
        Returns an empty list when no image is given or no supported
        descriptor is selected.
    """
    if query_image is None:
        return []
    selected = selected_descriptor or ()

    if "Color Descriptor" in selected:
        index_name = "color_embeddings"
        query_embedding = color_descriptor.describe(query_image)
    elif "CLIP" in selected:
        index_name = "clip_embeddings"
        query_embedding = clip_encoder.encode_image(query_image)
    else:
        print("This descriptor is not yet supported :(")
        return []

    # The FAISS-backed lookup takes a NumPy array query. The colour
    # descriptor's output is converted here, as in the original; for CLIP
    # this assumes encode_image returns array-like data — TODO confirm.
    _scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        index_name, np.asarray(query_embedding), k=top_k
    )
    # retrieved_examples is a dict of columns (images and embeddings).
    return retrieved_examples["image"]


# Define the Gradio Interface
iface = gr.Interface(
    fn=get_neighbors,
    inputs=[
        gr.Image(type="pil", label="Your Image"),
        gr.CheckboxGroup(["Color Descriptor", "LBP", "CLIP"], label="Descriptor method?"),
    ],
    outputs=gr.Gallery(),
    title="Image Similarity Gallery",
    description="Upload an image and get similar images",
    allow_flagging="never",
)

# Launch the Gradio interface
iface.launch()