Samuel Schmidt
Update src/app.py
5a5b371
raw
history blame
5.09 kB
from colordescriptor import ColorDescriptor
from CLIP import CLIPImageEncoder
from LBP import LBPImageEncoder
import gradio as gr
import os
import cv2
import numpy as np
from datasets import *
# Download the CelebA-faces dataset from the Hugging Face Hub.
# NOTE(review): download_mode='force_redownload' ignores any cached copy and
# re-fetches the dataset on every start-up — confirm this is intentional.
dataset = load_dataset("huggan/CelebA-faces", download_mode='force_redownload')
dataset.cleanup_cache_files()
# Index only the first 10 images: this is a small demo-scale CBIR app.
candidate_subset = dataset["train"].select(range(10)) # This is a small CBIR app! :D
def index_dataset(dataset):
    """Compute LBP, color, and CLIP embeddings for every row and index them.

    Builds one FAISS index per embedding column and persists each index to
    disk so that `check_index` can reload them on the next start-up.

    Args:
        dataset: A `datasets.Dataset` whose rows contain an "image" column
            (PIL images).

    Returns:
        The dataset extended with 'lbp_embeddings', 'color_embeddings' and
        'clip_embeddings' columns, each backed by an attached FAISS index.
    """
    print(dataset)

    print("LBP Embeddings")
    lbp_model = LBPImageEncoder(8, 2)
    dataset_with_embeddings = dataset.map(
        lambda row: {'lbp_embeddings': lbp_model.describe(row["image"])})

    print("Color Embeddings")
    cd = ColorDescriptor((8, 12, 3))
    dataset_with_embeddings = dataset_with_embeddings.map(
        lambda row: {'color_embeddings': cd.describe(row["image"])})

    print("CLIP Embeddings")
    clip_model = CLIPImageEncoder()
    # Batched: the CLIP encoder processes several images per forward pass.
    dataset_with_embeddings = dataset_with_embeddings.map(
        clip_model.encode_images, batched=True, batch_size=16)

    # Build and persist one FAISS index per descriptor.
    # BUG FIX: the original added the LBP index but never saved it, so it
    # could not be reloaded from disk on subsequent runs.
    for column, path in (('color_embeddings', 'color_index.faiss'),
                         ('clip_embeddings', 'clip_index.faiss'),
                         ('lbp_embeddings', 'lbp_index.faiss')):
        dataset_with_embeddings.add_faiss_index(column=column)
        dataset_with_embeddings.save_faiss_index(column, path)

    print(dataset_with_embeddings)
    return dataset_with_embeddings
def check_index(ds):
    """Attach FAISS indexes to *ds*, reloading persisted ones when available.

    Falls back to a full (slow) re-index via `index_dataset` when any index
    file is missing on disk.

    Args:
        ds: A `datasets.Dataset` with an "image" column.

    Returns:
        The dataset with all descriptor indexes attached.
    """
    index_files = {
        'color_embeddings': 'color_index.faiss',
        'clip_embeddings': 'clip_index.faiss',
        'lbp_embeddings': 'lbp_index.faiss',
    }
    if all(os.path.isfile(path) for path in index_files.values()):
        for column, path in index_files.items():
            ds.load_faiss_index(column, path)
        # BUG FIX: load_faiss_index mutates the dataset in place and returns
        # None; the original returned that None, so the module-level
        # dataset_with_embeddings was unusable on the cached path.
        return ds
    return index_dataset(ds)
# Attach (or build) the FAISS indexes once at import time, so every Gradio
# request below can query them directly.
dataset_with_embeddings = check_index(candidate_subset)

# Main function, to find similar images
# TODO: implement different distance measures
def get_neighbors(query_image, selected_descriptor, top_k=5):
    """Returns the top k nearest examples to the query image.

    Args:
        query_image: A PIL object representing the query image.
        selected_descriptor: The list of descriptor names chosen in the UI
            (any of "Color Descriptor", "CLIP", "LBP"). When several are
            selected, only the first match in that order is used.
        top_k: An integer representing the number of nearest examples to return.

    Returns:
        A list of the top_k most similar images as PIL objects, or an empty
        list when no supported descriptor was selected.
    """
    # The original chained three independent `if`s whose final `else` bound
    # only to the last one; an explicit elif chain makes the dispatch clear
    # and lets the retrieval code exist once instead of three times.
    if "Color Descriptor" in selected_descriptor:
        cd = ColorDescriptor((8, 12, 3))
        qi_embedding = cd.describe(query_image)
        column = 'color_embeddings'
    elif "CLIP" in selected_descriptor:
        clip_model = CLIPImageEncoder()
        qi_embedding = clip_model.encode_image(query_image)
        column = 'clip_embeddings'
    elif "LBP" in selected_descriptor:
        lbp_model = LBPImageEncoder(8, 2)
        qi_embedding = lbp_model.describe(query_image)
        column = 'lbp_embeddings'
    else:
        print("This descriptor is not yet supported :(")
        return []

    # get_nearest_examples expects a numpy query vector; the original only
    # converted the color path — convert uniformly for all descriptors.
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        column, np.asarray(qi_embedding), k=top_k)
    # retrieved_examples is a dict of columns; 'image' holds the PIL objects.
    return retrieved_examples['image']
def load_cbir_dataset(datasetname, size=1000):
    """Placeholder for switching the indexed dataset from the UI.

    Wired to the "Switch Dataset" button; not implemented yet.

    Args:
        datasetname: Hub id of the dataset to load (from the dropdown).
        size: Intended number of rows to index.  # TODO: implement
    """
    pass
# Define the Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("""
# Welcome to this CBIR app
This is a CBIR app focused on the retrieval of similar faces.
## Find similar images
Here you can upload an image, that is compared with existing image in our dataset.
""")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Please upload your image")
        gallery_output = gr.Gallery()
    btn = gr.Button(value="Submit")

    gr.Markdown("""
## Settings
Here you can adjust how the images are found
""")
    with gr.Row():
        checkboxes_descr = gr.CheckboxGroup(
            ["Color Descriptor", "LBP", "CLIP"],
            label="Please choose an descriptor")
        # BUG FIX: a single-select Dropdown takes a string default; the
        # original passed a one-element list.
        dataset_dropdown = gr.Dropdown(
            ["huggan/CelebA-faces", "EIT/cbir-eit"],
            value="huggan/CelebA-faces",
        )
        btn_index = gr.Button(value="Switch Dataset")

    # Wire the buttons: dataset switching (stub) and similarity search.
    btn_index.click(load_cbir_dataset, inputs=[dataset_dropdown])
    btn.click(get_neighbors,
              inputs=[image_input, checkboxes_descr],
              outputs=[gallery_output])

    # TODO(upload): a user-upload section (gr.UploadButton + gr.File) was
    # drafted here but referenced an undefined `upload_file` handler; the
    # dead commented-out code was removed until that handler exists.

if __name__ == "__main__":
    demo.launch()