Update app.py
app.py (changed):
```python
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image, UnidentifiedImageError
import gradio as gr
from pathlib import Path
import os

# Select the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP model and processor
def load_clip_model(device):
    try:
        model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K").to(device)
        processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
        return model, processor
    except Exception as e:
        print("Error loading the CLIP model:", e)
        raise

# Load precomputed image embeddings from the CSV
def load_embeddings(embedding_file):
    try:
        df = pd.read_csv(embedding_file)
        assert 'filename' in df.columns, "The 'filename' column is required in the CSV"
        embeddings = df.drop(columns=['filename']).values
        image_paths = df['filename'].tolist()
        return embeddings, image_paths
    except Exception as e:
        print("Error loading the embeddings:", e)
        raise
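# Assumed embeddings.csv layout (the file itself is not part of this commit;
# column names are illustrative):
#   filename,dim_0,dim_1,...,dim_767
# one row per image under img/, with 768 feature columns for CLIP ViT-L/14.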

# Find the most similar images for a text query
def query_images(text, model, processor, image_embeddings, image_paths, device):
    try:
        text_inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            text_embedding = model.get_text_features(**text_inputs).cpu().numpy().flatten()

        # Rank images by cosine similarity and keep the top 3
        similarities = cosine_similarity([text_embedding], image_embeddings)[0]
        top_indices = similarities.argsort()[-3:][::-1]
        return [(Path("img") / image_paths[i], similarities[i]) for i in top_indices]
    except Exception as e:
        print("Error computing similarities:", e)
        return []

# Prediction function for the Gradio interface
def predict(query_text):
    try:
        similar_images = query_images(query_text, model, processor, embeddings, image_paths, device)
        image_outputs = []
        scores = []

        for img_path, score in similar_images:
            try:
                img = Image.open(img_path).convert("RGB")
                image_outputs.append(img)
                scores.append(score)
            except (FileNotFoundError, UnidentifiedImageError) as e:
                print(f"Error opening image {img_path}: {e}")
                continue

        if not image_outputs:
            # No image could be loaded
            return [], pd.DataFrame([["No images found"]], columns=["Error"])

        df_scores = pd.DataFrame(scores, columns=["Similarity Score"])
        return image_outputs, df_scores
    except Exception as e:
        print("Error during prediction:", e)
        return [], pd.DataFrame([["Internal error"]], columns=["Error"])

# Entry point
if __name__ == "__main__":
    try:
        model, processor = load_clip_model(device)
        embeddings, image_paths = load_embeddings("embeddings.csv")

        interface = gr.Interface(
            fn=predict,
            inputs=gr.Textbox(label="Enter your text"),
            outputs=[
                gr.Gallery(label="Similar images"),
                gr.Dataframe(label="Similarity scores")
            ],
            title="CLIP similar-image search",
            description="Enter a text description to find the most similar images in the database."
        )

        interface.launch(share=True)
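        # Note: on Hugging Face Spaces the app is already served publicly;
        # share=True mainly matters for local runs, where it requests a
        # temporary public gradio.live link.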

    except Exception as e:
        print("Error initializing the app:", e)
```
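The commit does not include the script that produces embeddings.csv. Below is a minimal sketch of how it could be generated, assuming the images live under img/ and reusing the same laion/CLIP-ViT-L-14-laion2B-s32B-b82K checkpoint; the build_embeddings helper and the dim_i column names are illustrative, not part of the Space.

```python
import pandas as pd
import torch
from pathlib import Path
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

def build_embeddings(image_dir="img", out_csv="embeddings.csv"):
    # Hypothetical helper: encode every image with the same CLIP checkpoint
    # used by app.py, so text and image features live in the same space.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K").to(device)
    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")

    rows = []
    for path in sorted(Path(image_dir).glob("*")):
        try:
            image = Image.open(path).convert("RGB")
        except OSError:
            continue  # skip directories and non-image files
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            features = model.get_image_features(**inputs).cpu().numpy().flatten()
        # 'filename' is the only column load_embeddings() treats specially;
        # the remaining columns hold the raw feature dimensions.
        rows.append({"filename": path.name, **{f"dim_{i}": v for i, v in enumerate(features)}})

    pd.DataFrame(rows).to_csv(out_csv, index=False)

if __name__ == "__main__":
    build_embeddings()
```

Since app.py resolves each result as Path("img") / filename, the CSV should store bare file names rather than full paths.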