from pathlib import Path

import gradio as gr
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel


def load_clip_model(device, model_name="laion/CLIP-ViT-L-14-laion2B-s32B-b82K"):
    """Load a CLIP model and its matching processor.

    Args:
        device: Target passed to ``model.to`` (e.g. ``"cuda"`` or ``"cpu"``).
        model_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the model and the processor. Defaults to the checkpoint
            this script was written for.

    Returns:
        A ``(model, processor)`` tuple; the model has been moved to
        ``device``.
    """
    # One identifier feeds both calls so the model and the processor can
    # never end up pointing at different checkpoints.
    model = CLIPModel.from_pretrained(model_name).to(device)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor
def load_embeddings(embedding_file):
    """Read precomputed image embeddings and their file names from a CSV.

    The CSV must contain a ``filename`` column; every other column is
    treated as one embedding dimension.

    Args:
        embedding_file: Path or file-like object accepted by
            ``pandas.read_csv``.

    Returns:
        ``(embeddings, image_paths)``: a 2-D float array with one row per
        CSV row, and the list of ``filename`` values in the same row order.

    Raises:
        KeyError: If the CSV has no ``filename`` column.
        ValueError: If a non-``filename`` column holds non-numeric data.
    """
    df = pd.read_csv(embedding_file)
    image_paths = df["filename"].tolist()
    # Pick the vector columns by name rather than by position so that a
    # ``filename`` column which is not first cannot leak strings into the
    # matrix; the float conversion fails fast on any other text column.
    embeddings = df.drop(columns="filename").to_numpy(dtype=float)
    return embeddings, image_paths
def query_images(text, model, processor, image_embeddings, image_paths,
                 device, top_k=3, image_dir="img"):
    """Rank the stored images by cosine similarity to a text query.

    Args:
        text: Query string.
        model: Object exposing ``get_text_features`` (a CLIP model here).
        processor: Callable that turns ``text`` into the model's inputs.
        image_embeddings: 2-D array with one image embedding per row.
        image_paths: File names, aligned row-for-row with
            ``image_embeddings``.
        device: Device the tokenized inputs are moved to.
            NOTE(review): presumably must match the model's device.
        top_k: Maximum number of matches to return (default 3).
        image_dir: Directory the file names are joined onto
            (default ``"img"``).

    Returns:
        A list of ``(path, similarity)`` tuples, best match first. The list
        is empty when there are no embeddings or ``top_k`` is not positive.
    """
    # Nothing to rank: skip the model call, and avoid cosine_similarity
    # rejecting an empty matrix.
    if top_k <= 0 or len(image_embeddings) == 0:
        return []

    text_inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
    text_embedding = text_features.cpu().numpy().flatten()

    similarities = cosine_similarity([text_embedding], image_embeddings)[0]

    # argsort is ascending, so reverse it and keep the leading top_k entries.
    top_indices = similarities.argsort()[::-1][:top_k]

    base_dir = Path(image_dir)
    return [(base_dir / image_paths[i], similarities[i]) for i in top_indices]
def predict(query_text):
    """Gradio callback: return the best-matching images and their scores.

    Relies on the module-level ``model``, ``processor``, ``embeddings``,
    ``image_paths`` and ``device`` that the ``__main__`` block binds, so it
    only works when this file is run as a script.

    Args:
        query_text: Free-text query.

    Returns:
        ``(image_outputs, scores_formatted)``: the matched PIL images, and
        one single-element row ``[score]`` per image in the same order.
    """
    similar_images = query_images(query_text, model, processor, embeddings, image_paths, device)

    image_outputs = []
    scores_formatted = []
    for img_path, score in similar_images:
        # Image.open is lazy and keeps the file open; copy the pixels out
        # inside the context manager so the handle is released right away.
        with Image.open(img_path) as img:
            image_outputs.append(img.copy())
        # Each score becomes its own row of the one-column output table.
        scores_formatted.append([score])

    return image_outputs, scores_formatted
if __name__ == "__main__":
    # These names are bound at module level on purpose: predict() reads
    # model, processor, embeddings, image_paths and device as globals.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = load_clip_model(device)

    # Precomputed image embeddings, read once at start-up.
    embedding_file = "embeddings.csv"
    embeddings, image_paths = load_embeddings(embedding_file)

    # One text box in; a gallery of matches plus a one-column score table out.
    interface = gr.Interface(
        fn=predict,
        inputs="text",
        outputs=[
            gr.Gallery(label="Similar Images", elem_id="image_gallery"),
            gr.Dataframe(label="Similarity Scores", headers=["Score"]),
        ],
        title="Find Similar Images",
        description="Insert text to find the three most similar images.",
    )

    interface.launch()