# ClipArte / app.py
# paolodegasperis's picture
# Update app.py
# f2b0ee5 verified
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image, UnidentifiedImageError
import gradio as gr
from pathlib import Path
import os
# Select the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load the CLIP model
def load_clip_model(device, model_name="laion/CLIP-ViT-L-14-laion2B-s32B-b82K"):
    """Load a pretrained CLIP model together with its matching processor.

    Args:
        device: Torch device (for example ``"cuda"`` or ``"cpu"``) the model
            is moved to.
        model_name: Identifier passed to ``from_pretrained`` for both the
            model and the processor. Defaults to the checkpoint that used to
            be hard-coded here, so existing callers are unaffected.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        Exception: Any error raised while loading is printed and then
            re-raised unchanged.
    """
    try:
        model = CLIPModel.from_pretrained(model_name).to(device)
        processor = CLIPProcessor.from_pretrained(model_name)
    except Exception as e:
        print("Errore nel caricamento del modello CLIP:", e)
        # Bare ``raise`` re-raises the active exception with its traceback.
        raise
    return model, processor
# Load embeddings from the CSV
def load_embeddings(embedding_file):
    """Read precomputed image embeddings from a CSV file.

    The CSV must contain a ``filename`` column; every other column is taken
    as part of the embedding matrix.

    Args:
        embedding_file: Path or file-like object accepted by
            ``pandas.read_csv``.

    Returns:
        A ``(embeddings, image_paths)`` tuple: the values of all columns
        except ``filename`` (one row per CSV row) and the list of
        ``filename`` values in the same row order.

    Raises:
        ValueError: If the CSV has no ``filename`` column.
        Exception: Any other error raised while reading the file. Every
            error is printed before being re-raised.
    """
    try:
        df = pd.read_csv(embedding_file)
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if 'filename' not in df.columns:
            raise ValueError("La colonna 'filename' è obbligatoria nel CSV")
        embeddings = df.drop(columns=['filename']).values
        image_paths = df['filename'].tolist()
    except Exception as e:
        print("Errore nel caricamento degli embeddings:", e)
        raise
    return embeddings, image_paths
# Find similar images
def query_images(text, model, processor, image_embeddings, image_paths, device,
                 top_k=3, image_dir="img"):
    """Return the images whose embeddings are most similar to a text query.

    Args:
        text: Query string.
        model: CLIP model exposing ``get_text_features``.
        processor: CLIP processor used to tokenize ``text``.
        image_embeddings: 2-D array-like with one image embedding per row.
        image_paths: File names aligned row-by-row with ``image_embeddings``.
        device: Torch device the tokenized inputs are moved to.
        top_k: Maximum number of results to return (default 3, the value
            that used to be hard-coded). Values <= 0 yield an empty list.
        image_dir: Directory prepended to each file name (default ``"img"``).

    Returns:
        A list of ``(path, similarity)`` tuples ordered from most to least
        similar. On any error the problem is printed and an empty list is
        returned (best-effort behaviour kept from the original).
    """
    try:
        # ``truncation=True`` asks the tokenizer to cut over-long queries at
        # its configured maximum length, so they are shortened rather than
        # raising and ending up in the ``except`` below with no results.
        text_inputs = processor(
            text=[text], return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            text_embedding = model.get_text_features(**text_inputs).cpu().numpy().flatten()
        similarities = cosine_similarity([text_embedding], image_embeddings)[0]
        # Reverse first, then slice: same order as ``[-k:][::-1]`` for k > 0,
        # but also correct for k == 0 (``[-0:]`` would select everything).
        top_indices = similarities.argsort()[::-1][:max(top_k, 0)]
        base_dir = Path(image_dir)
        return [(base_dir / image_paths[i], similarities[i]) for i in top_indices]
    except Exception as e:
        print("Errore nel calcolo delle similarità:", e)
        return []
# Prediction function
def predict(query_text):
    """Gradio callback: find and load the images best matching a text query.

    Reads the module-level ``model``, ``processor``, ``embeddings``,
    ``image_paths`` and ``device`` that the script assigns at start-up.

    Args:
        query_text: Text typed by the user.

    Returns:
        A ``(images, table)`` tuple. ``images`` is a list of RGB PIL images.
        ``table`` is a DataFrame with a ``"Similarity Score"`` column, or a
        one-cell DataFrame with an ``"Errore"`` column when no image could
        be loaded or an unexpected error occurred.
    """
    try:
        similar_images = query_images(query_text, model, processor, embeddings, image_paths, device)
        image_outputs = []
        scores = []
        for img_path, score in similar_images:
            try:
                # The context manager closes the underlying file; convert()
                # returns a separate in-memory image that stays usable.
                with Image.open(img_path) as img:
                    rgb_img = img.convert("RGB")
            except (OSError, UnidentifiedImageError) as e:
                # OSError also covers FileNotFoundError, PermissionError,
                # etc., so one unreadable file is skipped instead of
                # aborting the whole prediction.
                print(f"Errore nell'apertura immagine {img_path}: {e}")
                continue
            image_outputs.append(rgb_img)
            scores.append(score)
        if not image_outputs:
            # No image could be loaded
            return [], pd.DataFrame([["Nessuna immagine trovata"]], columns=["Errore"])
        df_scores = pd.DataFrame(scores, columns=["Similarity Score"])
        return image_outputs, df_scores
    except Exception as e:
        print("Errore durante la predizione:", e)
        return [], pd.DataFrame([["Errore interno"]], columns=["Errore"])
# Script entry point: load the model and embeddings, then serve the UI.
if __name__ == "__main__":
    try:
        # These module-level names are read by predict() at request time.
        model, processor = load_clip_model(device)
        embeddings, image_paths = load_embeddings("embeddings.csv")

        # One text input; outputs are an image gallery plus a score table.
        interface = gr.Interface(
            fn=predict,
            title="Ricerca immagini simili con CLIP",
            description="Inserisci un testo per trovare le immagini più affini nel database.",
            inputs=gr.Textbox(label="Inserisci il testo"),
            outputs=[
                gr.Gallery(label="Immagini simili"),
                gr.Dataframe(label="Punteggi di similarità"),
            ],
        )
        interface.launch(share=True)
    except Exception as exc:
        # Start-up failures are reported on stdout and not re-raised.
        print("Errore durante l'inizializzazione dell'app:", exc)