File size: 3,529 Bytes
89702d9
 
 
 
f2b0ee5
89702d9
 
f2b0ee5
89702d9
f2b0ee5
 
 
 
89702d9
f2b0ee5
 
 
 
 
 
 
89702d9
f2b0ee5
89702d9
f2b0ee5
 
 
 
 
 
 
 
 
89702d9
 
 
f2b0ee5
 
 
 
89702d9
f2b0ee5
 
 
 
 
 
89702d9
f2b0ee5
89702d9
f2b0ee5
 
 
 
 
 
 
 
 
 
 
 
 
89702d9
f2b0ee5
 
 
89702d9
f2b0ee5
 
 
 
 
89702d9
f2b0ee5
89702d9
f2b0ee5
 
 
89702d9
f2b0ee5
 
 
 
 
 
 
 
 
 
89702d9
f2b0ee5
89702d9
f2b0ee5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image, UnidentifiedImageError
import gradio as gr
from pathlib import Path
import os

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP model and its processor.
def load_clip_model(device, model_name="laion/CLIP-ViT-L-14-laion2B-s32B-b82K"):
    """Load a CLIP model and its matching processor via ``from_pretrained``.

    Args:
        device: Torch device string (e.g. "cuda" or "cpu") the model is moved to.
        model_name: Hub identifier or local path of the checkpoint. The default
            is the LAION ViT-L/14 checkpoint the app was written for; presumably
            the stored embeddings were produced with it too — TODO confirm, as
            query and image embeddings must come from the same model.

    Returns:
        A ``(model, processor)`` tuple; the model is on ``device`` in eval mode.

    Raises:
        Whatever ``from_pretrained`` raises (network errors, missing files, ...),
        after printing a diagnostic message.
    """
    try:
        model = CLIPModel.from_pretrained(model_name).to(device)
        model.eval()  # inference only: make sure dropout-style layers are disabled
        processor = CLIPProcessor.from_pretrained(model_name)
    except Exception as e:
        print("Errore nel caricamento del modello CLIP:", e)
        raise  # bare raise re-raises the original exception with its traceback
    return model, processor

# Load precomputed image embeddings from a CSV file.
def load_embeddings(embedding_file):
    """Read image embeddings and their file names from a CSV.

    The CSV must contain a ``filename`` column; every other column is treated
    as one embedding dimension.

    Args:
        embedding_file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        ``(embeddings, image_paths)``: a float ``ndarray`` of shape
        ``(n_images, n_dims)`` and the list of ``n_images`` file names (as
        strings), in the same row order.

    Raises:
        ValueError: if the ``filename`` column is missing, there are no
            embedding columns, or an embedding value is not numeric.
        Any error raised by ``pandas.read_csv`` (e.g. file not found).
        Errors are printed before being re-raised.
    """
    try:
        df = pd.read_csv(embedding_file)
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if "filename" not in df.columns:
            raise ValueError("La colonna 'filename' è obbligatoria nel CSV")
        feature_df = df.drop(columns=["filename"])
        if feature_df.shape[1] == 0:
            raise ValueError("Il CSV non contiene colonne di embedding")
        # Force a numeric matrix: a stray text column fails here with a clear
        # error instead of producing an object array that breaks later.
        embeddings = feature_df.to_numpy(dtype=float)
        # str() so each name can always be joined onto a Path later on.
        image_paths = df["filename"].astype(str).tolist()
        return embeddings, image_paths
    except Exception as e:
        print("Errore nel caricamento degli embeddings:", e)
        raise

# Find the images most similar to a text query.
def query_images(text, model, processor, image_embeddings, image_paths, device,
                 top_k=3, image_dir="img"):
    """Rank stored image embeddings by cosine similarity to a text query.

    Args:
        text: Free-text query.
        model: CLIP model exposing ``get_text_features``.
        processor: Matching CLIP processor used to tokenize ``text``.
        image_embeddings: 2-D array with one row per image, aligned with
            ``image_paths`` (assumed to come from the same CLIP model — TODO
            confirm).
        image_paths: File names, resolved relative to ``image_dir``.
        device: Device the tokenized inputs are moved to (must match the model).
        top_k: Maximum number of results to return (default 3).
        image_dir: Directory the file names are joined onto (default "img").

    Returns:
        A list of ``(Path, float)`` pairs sorted by decreasing similarity. The
        list is empty if ``top_k`` is not positive or if any error occurs; the
        error is printed, not raised.
    """
    if top_k <= 0:
        # Guard: a non-positive slice bound would otherwise select (almost) every image.
        return []
    try:
        # truncation=True: CLIP's text encoder has a fixed maximum length, and
        # longer queries would otherwise raise and silently yield no results.
        text_inputs = processor(
            text=[text], return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            text_embedding = model.get_text_features(**text_inputs).cpu().numpy().flatten()

        similarities = cosine_similarity([text_embedding], image_embeddings)[0]
        # Indices of the top_k highest scores, best first.
        top_indices = similarities.argsort()[::-1][:top_k]
        return [(Path(image_dir) / image_paths[i], float(similarities[i])) for i in top_indices]
    except Exception as e:
        print("Errore nel calcolo delle similarità:", e)
        return []

# Prediction callback wired into the Gradio interface.
def predict(query_text):
    """Return the images most similar to ``query_text`` and their scores.

    Relies on the module-level ``model``, ``processor``, ``embeddings``,
    ``image_paths`` and ``device`` globals, which are only assigned by the
    ``__main__`` block.

    Args:
        query_text: Text typed by the user. ``None``, empty, or whitespace-only
            input is rejected without querying the model.

    Returns:
        ``(images, dataframe)``: a list of RGB ``PIL.Image`` objects for the
        gallery and a DataFrame whose "Similarity Score" column is aligned with
        them. On failure the list is empty and the DataFrame has a single
        "Errore" column holding a message.
    """
    if not isinstance(query_text, str) or not query_text.strip():
        return [], pd.DataFrame([["Inserisci un testo di ricerca"]], columns=["Errore"])
    try:
        similar_images = query_images(query_text, model, processor, embeddings, image_paths, device)
        image_outputs = []
        scores = []

        for img_path, score in similar_images:
            try:
                # The context manager closes the file handle. convert() returns
                # a fully loaded copy that stays valid after the file is closed.
                with Image.open(img_path) as src:
                    img = src.convert("RGB")
            except OSError as e:
                # OSError covers FileNotFoundError, PermissionError,
                # UnidentifiedImageError and truncated files: skip only this image.
                print(f"Errore nell'apertura immagine {img_path}: {e}")
                continue
            image_outputs.append(img)
            scores.append(score)

        if not image_outputs:
            # None of the candidate images could be loaded (or there were none).
            return [], pd.DataFrame([["Nessuna immagine trovata"]], columns=["Errore"])

        df_scores = pd.DataFrame(scores, columns=["Similarity Score"])
        return image_outputs, df_scores
    except Exception as e:
        print("Errore durante la predizione:", e)
        return [], pd.DataFrame([["Errore interno"]], columns=["Errore"])

# Entry point: load the model and embeddings, then start the web UI.
if __name__ == "__main__":
    try:
        model, processor = load_clip_model(device)
        embeddings, image_paths = load_embeddings("embeddings.csv")

        interface = gr.Interface(
            fn=predict,
            inputs=gr.Textbox(label="Inserisci il testo"),
            outputs=[
                gr.Gallery(label="Immagini simili"),
                gr.Dataframe(label="Punteggi di similarità")
            ],
            title="Ricerca immagini simili con CLIP",
            description="Inserisci un testo per trovare le immagini più affini nel database."
        )

        # share=True publishes the app through a public Gradio share link;
        # set the environment variable GRADIO_SHARE=0 to serve it only locally.
        interface.launch(share=os.environ.get("GRADIO_SHARE", "1") != "0")

    except Exception as e:
        print("Errore durante l'inizializzazione dell'app:", e)
        # Report the failure to the shell instead of exiting with status 0.
        raise SystemExit(1) from e