# embeddings_app / app.py
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pickle
import gradio as gr
class ImageDataset(Dataset):
    """Wraps a directory of images and preprocesses each one for CLIP."""

    def __init__(self, image_dir, processor):
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale/RGBA files don't trip up the processor.
        image = Image.open(self.image_paths[idx]).convert("RGB")
        # The processor adds a batch dimension; index [0] drops it again.
        return self.processor(images=image, return_tensors="pt")['pixel_values'][0]
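# Quick shape check (illustrative; assumes a local "dataset/" folder with images):
# ds = ImageDataset("dataset/", CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"))
# print(ds[0].shape)  # torch.Size([3, 224, 224]) for clip-vit-base-patch32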
def get_and_save_clip_embeddings(image_dir, output_file, batch_size=32, device='cuda'):
    """Embed every image in image_dir with CLIP and pickle the results."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    image_paths = []
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())
            # shuffle=False, so batches follow dataset.image_paths order and the
            # file paths can be recovered from the batch index.
            start_idx = batch_idx * batch_size
            end_idx = start_idx + len(batch)
            image_paths.extend(dataset.image_paths[start_idx:end_idx])

    all_embeddings = np.concatenate(all_embeddings)
    with open(output_file, 'wb') as f:
        pickle.dump({'embeddings': all_embeddings, 'image_paths': image_paths}, f)
# One-time pre-computation: run this once to build image_embeddings.pkl, then
# leave it commented out so the app only serves lookups.
# image_dir = "dataset/"
# output_file = "image_embeddings.pkl"
# batch_size = 32
# device = "cuda" if torch.cuda.is_available() else "cpu"
# get_and_save_clip_embeddings(image_dir, output_file, batch_size, device)
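# Quick check of the saved pickle (illustrative, assuming the file was built above):
# with open("image_embeddings.pkl", "rb") as f:
#     saved = pickle.load(f)
# print(saved["embeddings"].shape)   # (num_images, 512) for clip-vit-base-patch32
# print(len(saved["image_paths"]))   # should equal num_images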
# APP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the precomputed image embeddings and the matching file paths.
with open('image_embeddings.pkl', 'rb') as f:
    data = pickle.load(f)
image_embeddings = data['embeddings']
image_names = data['image_paths']
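# Possible optimization (an assumption, not part of the original app): the image
# embeddings never change at serving time, so they could be L2-normalized once
# here, leaving only the per-request query vector for cosine_similarity to normalize.
# image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=-1, keepdims=True)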
def cosine_similarity(a, b):
    """Cosine similarity between every row of a and every row of b."""
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return np.dot(a, b.T)
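# Note: CLIPModel.get_image_features/get_text_features return unnormalized
# projection vectors, which is why both sides are normalized above.
# Tiny sanity check (illustrative):
# v = np.array([[1.0, 0.0]]); w = np.array([[0.0, 1.0]])
# cosine_similarity(v, v)  # -> [[1.]]
# cosine_similarity(v, w)  # -> [[0.]]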
def find_similar_images(text):
    """Return the 4 gallery images whose CLIP embeddings best match the text."""
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs).cpu().numpy()
    similarities = cosine_similarity(text_embedding, image_embeddings)
    # Sort descending by similarity and keep the top 4 matches.
    top_indices = np.argsort(similarities[0])[::-1][:4]
    top_images = [image_names[i] for i in top_indices]
    return top_images
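# Variant sketch (an assumption, not part of the original app): gr.Gallery also
# accepts (image, caption) tuples, so each result could show its similarity
# score by returning pairs instead of bare paths:
# def find_similar_images_with_scores(text):
#     inputs = processor(text=[text], return_tensors="pt", padding=True)
#     with torch.no_grad():
#         text_embedding = model.get_text_features(**inputs).cpu().numpy()
#     scores = cosine_similarity(text_embedding, image_embeddings)[0]
#     top_indices = np.argsort(scores)[::-1][:4]
#     return [(image_names[i], f"score: {scores[i]:.3f}") for i in top_indices]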
text_input = gr.Textbox(label="Input text", placeholder="Describe the image you are looking for")
imgs_output = gr.Gallery(label="Top 4 most similar images")

intf = gr.Interface(fn=find_similar_images, inputs=text_input, outputs=imgs_output)
intf.launch(share=True)