import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pickle
import gradio as gr

class ImageDataset(Dataset):
    """Dataset that yields CLIP-preprocessed pixel tensors for every image in a directory."""

    def __init__(self, image_dir, processor):
        self.image_paths = [
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale or RGBA files don't break the CLIP processor.
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.processor(images=image, return_tensors="pt")['pixel_values'][0]

def get_and_save_clip_embeddings(image_dir, output_file, batch_size=32, device='cuda'):
    """Embed every image in image_dir with CLIP and pickle the embeddings with their paths."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    image_paths = []

    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())
            # shuffle=False, so batches follow dataset.image_paths order.
            start_idx = batch_idx * batch_size
            end_idx = start_idx + len(batch)
            image_paths.extend(dataset.image_paths[start_idx:end_idx])

    all_embeddings = np.concatenate(all_embeddings)

    with open(output_file, 'wb') as f:
        pickle.dump({'embeddings': all_embeddings, 'image_paths': image_paths}, f)

# image_dir = "dataset/"
# output_file = "image_embeddings.pkl"
# batch_size = 32
# device = "cuda" if torch.cuda.is_available() else "cpu"

# get_and_save_clip_embeddings(image_dir, output_file, batch_size, device)
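
# Optional sanity check (a minimal sketch, assuming image_embeddings.pkl was written by
# the call above): load the pickle and confirm that the number of embeddings matches the
# number of recorded paths.
# with open("image_embeddings.pkl", 'rb') as f:
#     saved = pickle.load(f)
# assert saved['embeddings'].shape[0] == len(saved['image_paths'])
# print(saved['embeddings'].shape)  # e.g. (num_images, 512) for clip-vit-base-patch32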


# APP: Gradio text-to-image search interface

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

with open('image_embeddings.pkl', 'rb') as f:
    data = pickle.load(f)
image_embeddings = data['embeddings']
image_names = data['image_paths']

def cosine_similarity(a, b):
    """Row-wise cosine similarity between two 2-D arrays; returns an (a_rows, b_rows) matrix."""
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return np.dot(a, b.T)
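
# Worked example (sketch): an identical vector scores 1, an orthogonal vector scores 0.
# cosine_similarity(np.array([[1.0, 0.0]]), np.array([[1.0, 0.0], [0.0, 1.0]]))
# -> array([[1., 0.]])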

def find_similar_images(text):
    """Return the paths of the 4 images whose CLIP embeddings best match the query text."""
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs).cpu().numpy()

    # Rank all precomputed image embeddings by similarity to the text embedding.
    similarities = cosine_similarity(text_embedding, image_embeddings)
    top_indices = np.argsort(similarities[0])[::-1][:4]
    top_images = [image_names[i] for i in top_indices]

    return top_images
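
# Quick local test (optional sketch; assumes image_embeddings.pkl is present and the
# model/processor above are loaded):
# print(find_similar_images("a photo of a dog"))
# -> a list of four image file paths, ranked by CLIP text-image similarity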



text_input = gr.Textbox(label="Input text", placeholder="Enter a description of the images")
imgs_output = gr.Gallery(label="Top 4 most similar images")

intf = gr.Interface(
    fn=find_similar_images,
    inputs=text_input,
    outputs=imgs_output,
)
extra_text = gr.Markdown("""
The dataset contains images of dogs, rabbits, sharks, and deer.
Displaying the images might take a couple of seconds.
""")
# Components created above are attached to the page when render() is called inside Blocks.
with gr.Blocks() as app:
    intf.render()
    extra_text.render()

app.launch(share=True)