import os
import pickle

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPModel, CLIPProcessor


class ImageDataset(Dataset):
    """Loads images from a directory and preprocesses them for CLIP."""

    def __init__(self, image_dir, processor):
        self.image_paths = sorted(
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale/RGBA images batch consistently.
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.processor(images=image, return_tensors="pt")["pixel_values"][0]


def get_and_save_clip_embeddings(image_dir, output_file, batch_size=32, device="cuda"):
    """Embed every image in image_dir with CLIP and pickle the embeddings with their paths."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    image_paths = []

    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())

            # shuffle=False, so batch order matches dataset.image_paths order.
            start_idx = batch_idx * batch_size
            end_idx = start_idx + len(batch)
            image_paths.extend(dataset.image_paths[start_idx:end_idx])

    all_embeddings = np.concatenate(all_embeddings)
    with open(output_file, "wb") as f:
        pickle.dump({"embeddings": all_embeddings, "image_paths": image_paths}, f)


# One-time preprocessing step; uncomment to (re)build the embedding index.
# image_dir = "dataset/"
# output_file = "image_embeddings.pkl"
# batch_size = 32
# device = "cuda" if torch.cuda.is_available() else "cpu"
# get_and_save_clip_embeddings(image_dir, output_file, batch_size, device)


# APP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

with open("image_embeddings.pkl", "rb") as f:
    data = pickle.load(f)
image_embeddings = data["embeddings"]
image_names = data["image_paths"]


def cosine_similarity(a, b):
    """Row-wise cosine similarity between two 2-D arrays."""
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return np.dot(a, b.T)


def find_similar_images(text):
    """Embed the query text and return the 4 closest images by cosine similarity."""
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs).cpu().numpy()

    similarities = cosine_similarity(text_embedding, image_embeddings)
    top_indices = np.argsort(similarities[0])[::-1][:4]
    return [image_names[i] for i in top_indices]


intf = gr.Interface(
    fn=find_similar_images,
    inputs=gr.Textbox(label="Input text", placeholder="Enter an image description"),
    outputs=gr.Gallery(label="Top 4 most similar images"),
)

extra_text = gr.Markdown(
    """
    The dataset contains images of dogs, rabbits, sharks, and deer.
    Displaying the images might take a couple of seconds.
    """
)

with gr.Blocks() as app:
    intf.render()
    extra_text.render()

app.launch(share=True)
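

# A minimal sanity check (hypothetical query string; requires image_embeddings.pkl
# to exist) -- uncomment and run in place of app.launch() to confirm the index
# loads and retrieval returns sensible paths before exposing the Gradio UI:
# print(find_similar_images("a dog running on grass"))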