paolodegasperis committed
Commit f2b0ee5 · verified · 1 Parent(s): 89702d9

Update app.py

Files changed (1)
  1. app.py +72 -49
app.py CHANGED
@@ -2,71 +2,94 @@ import pandas as pd
 import torch
 from transformers import CLIPProcessor, CLIPModel
 from sklearn.metrics.pairwise import cosine_similarity
-from PIL import Image
+from PIL import Image, UnidentifiedImageError
 import gradio as gr
 from pathlib import Path
+import os
 
-# Load CLIP model and processor
+# Set device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load CLIP model
 def load_clip_model(device):
-    model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K").to(device)
-    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
-    return model, processor
+    try:
+        model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K").to(device)
+        processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
+        return model, processor
+    except Exception as e:
+        print("Error loading CLIP model:", e)
+        raise e
 
-# Load embeddings from CSV
+# Load embeddings from the CSV
 def load_embeddings(embedding_file):
-    df = pd.read_csv(embedding_file)
-    embeddings = df.iloc[:, 1:].values  # Exclude the 'filename' column
-    image_paths = df['filename'].tolist()
-    return embeddings, image_paths
+    try:
+        df = pd.read_csv(embedding_file)
+        assert 'filename' in df.columns, "The 'filename' column is required in the CSV"
+        embeddings = df.drop(columns=['filename']).values
+        image_paths = df['filename'].tolist()
+        return embeddings, image_paths
+    except Exception as e:
+        print("Error loading embeddings:", e)
+        raise e
 
 # Find similar images
 def query_images(text, model, processor, image_embeddings, image_paths, device):
-    text_inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
-    with torch.no_grad():
-        text_embedding = model.get_text_features(**text_inputs).cpu().numpy().flatten()
-
-    similarities = cosine_similarity([text_embedding], image_embeddings)[0]
-    top_indices = similarities.argsort()[-3:][::-1]
-    return [(Path("img") / image_paths[i], similarities[i]) for i in top_indices]
+    try:
+        text_inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
+        with torch.no_grad():
+            text_embedding = model.get_text_features(**text_inputs).cpu().numpy().flatten()
+
+        similarities = cosine_similarity([text_embedding], image_embeddings)[0]
+        top_indices = similarities.argsort()[-3:][::-1]
+        return [(Path("img") / image_paths[i], similarities[i]) for i in top_indices]
+    except Exception as e:
+        print("Error computing similarities:", e)
+        return []
 
-# Function for Gradio
+# Prediction function
 def predict(query_text):
-    similar_images = query_images(query_text, model, processor, embeddings, image_paths, device)
-    image_outputs = []
-    scores = []
-
-    for img_path, score in similar_images:
-        try:
-            img = Image.open(img_path).convert("RGB")
-            image_outputs.append(img)
-            scores.append(score)
-        except Exception as e:
-            print(f"Error opening image {img_path}: {e}")
-            continue
-
-    df_scores = pd.DataFrame(scores, columns=["Similarity Score"])
-    return image_outputs, df_scores
+    try:
+        similar_images = query_images(query_text, model, processor, embeddings, image_paths, device)
+        image_outputs = []
+        scores = []
+
+        for img_path, score in similar_images:
+            try:
+                img = Image.open(img_path).convert("RGB")
+                image_outputs.append(img)
+                scores.append(score)
+            except (FileNotFoundError, UnidentifiedImageError) as e:
+                print(f"Error opening image {img_path}: {e}")
+                continue
+
+        if not image_outputs:
+            # No image could be loaded
+            return [], pd.DataFrame([["No images found"]], columns=["Error"])
+
+        df_scores = pd.DataFrame(scores, columns=["Similarity Score"])
+        return image_outputs, df_scores
+    except Exception as e:
+        print("Error during prediction:", e)
+        return [], pd.DataFrame([["Internal error"]], columns=["Error"])
 
-# Launch Gradio
+# Execution
 if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model, processor = load_clip_model(device)
-
-    # Load embeddings
-    embedding_file = "embeddings.csv"
-    embeddings, image_paths = load_embeddings(embedding_file)
-
-    # Gradio interface
-    interface = gr.Interface(
-        fn=predict,
-        inputs=gr.Textbox(label="Enter your text"),
-        outputs=[
-            gr.Gallery(label="Top 3 Similar Images"),
-            gr.Dataframe(label="Similarity Scores")
-        ],
-        title="CLIP Image Finder",
-        description="Enter a textual description to find the most similar images using CLIP."
-    )
-
-    # `share=True` is required for Hugging Face
-    interface.launch(share=True)
+    try:
+        model, processor = load_clip_model(device)
+        embeddings, image_paths = load_embeddings("embeddings.csv")
+
+        interface = gr.Interface(
+            fn=predict,
+            inputs=gr.Textbox(label="Enter the text"),
+            outputs=[
+                gr.Gallery(label="Similar images"),
+                gr.Dataframe(label="Similarity scores")
+            ],
+            title="Similar image search with CLIP",
+            description="Enter text to find the most similar images in the database."
+        )
+
+        interface.launch(share=True)
+
+    except Exception as e:
+        print("Error during app initialization:", e)