# ClipArte / app.py
# Uploaded by phalanx80 ("Upload 2 files", commit 8f20af0, verified).
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
import gradio as gr
from pathlib import Path
def load_clip_model(device, model_name="laion/CLIP-ViT-L-14-laion2B-s32B-b82K"):
    """Load a CLIP model and its matching processor from the Hugging Face Hub.

    Args:
        device: Torch device (e.g. ``"cuda"`` or ``"cpu"``) the model is moved to.
        model_name: Hub checkpoint identifier used for both the model and the
            processor. Defaults to the LAION ViT-L/14 checkpoint. Presumably it
            must be the same checkpoint that produced the stored image
            embeddings, otherwise text and image vectors are not comparable
            (assumption - the CSV's origin is not shown here).

    Returns:
        A ``(model, processor)`` tuple.
    """
    # Use one identifier for both objects so the model and the processor can
    # never drift apart.
    model = CLIPModel.from_pretrained(model_name).to(device)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor
def load_embeddings(embedding_file):
    """Read precomputed image embeddings from a CSV file.

    The CSV must contain a ``filename`` column. Every other column is treated
    as one embedding dimension.

    Args:
        embedding_file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        ``(embeddings, image_paths)``:
        - ``embeddings`` is a 2-D NumPy array with one row per CSV row.
        - ``image_paths`` is the list of file names in the same row order.
    """
    df = pd.read_csv(embedding_file)
    # Select both parts by column name rather than by position, so the result
    # stays correct even when 'filename' is not the first column.
    image_paths = df["filename"].tolist()
    embeddings = df.drop(columns="filename").to_numpy()
    return embeddings, image_paths
def query_images(text, model, processor, image_embeddings, image_paths, device,
                 top_k=3, image_dir="img"):
    """Return the images whose embeddings are most similar to a text query.

    Args:
        text: Free-text query.
        model: CLIP model used to embed the text via ``get_text_features``.
        processor: CLIP processor (tokenizer) matching ``model``.
        image_embeddings: 2-D array-like, one row per image. It must lie in
            the same embedding space as ``model.get_text_features`` for the
            similarity to be meaningful.
        image_paths: File names aligned with the rows of ``image_embeddings``.
        device: Device the tokenized text is moved to (should match ``model``).
        top_k: Maximum number of results to return (default 3).
        image_dir: Directory the file names are joined onto (default ``"img"``).

    Returns:
        A list of ``(Path, score)`` tuples sorted by decreasing cosine
        similarity. There are at most ``top_k`` entries, fewer if there are
        fewer images, and none if ``top_k <= 0``.
    """
    # Embed the query text. CLIP's text encoder has a fixed maximum sequence
    # length, so truncate long queries instead of letting the forward pass fail.
    text_inputs = processor(
        text=[text], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        text_embedding = model.get_text_features(**text_inputs).cpu().numpy().flatten()

    # Cosine similarity between the text embedding and every image embedding.
    similarities = cosine_similarity([text_embedding], image_embeddings)[0]

    # Indices in decreasing-similarity order, then keep the first top_k.
    # Slicing the reversed order (instead of ``argsort()[-top_k:]``) also
    # behaves correctly for top_k == 0; max() guards negative values.
    top_indices = similarities.argsort()[::-1][:max(top_k, 0)]

    return [(Path(image_dir) / image_paths[i], similarities[i]) for i in top_indices]
def predict(query_text):
    """Gradio callback: find the images most similar to ``query_text``.

    Reads the module-level ``model``, ``processor``, ``embeddings``,
    ``image_paths`` and ``device`` globals. These are assigned in the
    ``__main__`` block before the interface is launched.

    Returns:
        ``(images, scores)``:
        - ``images`` is a list of PIL images for the gallery.
        - ``scores`` is a list of one-element rows (``[[score], ...]``) for
          the similarity-score table, in the same order as ``images``.
    """
    similar_images = query_images(query_text, model, processor, embeddings, image_paths, device)
    image_outputs = []
    scores_formatted = []
    for img_path, score in similar_images:
        # Image.open is lazy and keeps the file handle open. Copy the pixel
        # data into memory inside a context manager so the file gets closed.
        with Image.open(img_path) as img:
            image_outputs.append(img.copy())
        # One row per image for the DataFrame. Convert to a built-in float,
        # since the score may be a NumPy scalar.
        scores_formatted.append([float(score)])
    return image_outputs, scores_formatted
if __name__ == "__main__":
    # Run on the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model, processor = load_clip_model(device)

    # Load the precomputed image embeddings from the CSV file.
    # These globals (device, model, processor, embeddings, image_paths) are
    # read by predict().
    embedding_file = "embeddings.csv"  # Replace with the correct path
    embeddings, image_paths = load_embeddings(embedding_file)

    # Output components: a gallery for the matched images and a one-column
    # table for their similarity scores.
    gallery = gr.Gallery(label="Similar Images", elem_id="image_gallery")
    score_table = gr.Dataframe(label="Similarity Scores", headers=["Score"])

    # Wire the predict callback into a text-in Gradio interface and serve it.
    interface = gr.Interface(
        fn=predict,
        inputs="text",
        outputs=[gallery, score_table],
        title="Find Similar Images",
        description="Insert text to find the three most similar images.",
    )
    interface.launch()