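# Text-to-image search with CLIP.
# Part 1 precomputes CLIP embeddings for a directory of images and pickles them;
# Part 2 is a Gradio app that embeds a text query and returns the closest images.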
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pickle
import gradio as gr
class ImageDataset(Dataset):
    """Loads images from a directory and preprocesses them for CLIP."""

    def __init__(self, image_dir, processor):
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
                            if f.endswith(('.png', '.jpg', '.jpeg'))]
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale/RGBA images don't break preprocessing.
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.processor(images=image, return_tensors="pt")['pixel_values'][0]
def get_and_save_clip_embeddings(image_dir, output_file, batch_size=32, device='cuda'):
    """Embed every image in image_dir with CLIP and pickle the embeddings with their paths."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    image_paths = []
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())
            # Keep file paths aligned with embedding rows
            # (shuffle=False, so batches follow dataset order).
            start_idx = batch_idx * batch_size
            end_idx = start_idx + len(batch)
            image_paths.extend(dataset.image_paths[start_idx:end_idx])

    all_embeddings = np.concatenate(all_embeddings)
    with open(output_file, 'wb') as f:
        pickle.dump({'embeddings': all_embeddings, 'image_paths': image_paths}, f)
# image_dir = "dataset/" | |
# output_file = "image_embeddings.pkl" | |
# batch_size = 32 | |
# device = "cuda" if torch.cuda.is_available() else "cpu" | |
# get_and_save_clip_embeddings(image_dir, output_file, batch_size, device) | |
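# The precompute step above is meant to be run once (e.g. on a GPU machine);
# the resulting image_embeddings.pkl is then shipped alongside the app below.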
# APP | |
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the precomputed image embeddings and the corresponding file paths.
with open('image_embeddings.pkl', 'rb') as f:
    data = pickle.load(f)
image_embeddings = data['embeddings']
image_names = data['image_paths']
def cosine_similarity(a, b):
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return np.dot(a, b.T)
def find_similar_images(text):
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs).cpu().numpy()
    # Rank all images by cosine similarity to the query and return the top 4 paths.
    similarities = cosine_similarity(text_embedding, image_embeddings)
    top_indices = np.argsort(similarities[0])[::-1][:4]
    top_images = [image_names[i] for i in top_indices]
    return top_images
text_input = gr.Textbox(label="Input text", placeholder="Enter an image description")
imgs_output = gr.Gallery(label="Top 4 most similar images")

intf = gr.Interface(fn=find_similar_images, inputs=text_input, outputs=imgs_output)
intf.launch(share=True)
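# Note: share=True creates a temporary public link when running locally;
# inside a Hugging Face Space the app is already served publicly, so it is not needed there.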