# embeddings_app / app.py — CLIP text-to-image search demo
# (Hugging Face Space by licmajster, commit 734f486 "Update to description")
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pickle
import gradio as gr
class ImageDataset(Dataset):
    """Torch ``Dataset`` yielding CLIP-preprocessed pixel tensors for every
    image file found directly inside ``image_dir``.

    Parameters
    ----------
    image_dir : str
        Directory scanned (non-recursively) for ``.png`` / ``.jpg`` /
        ``.jpeg`` files (extension match is case-insensitive).
    processor : CLIPProcessor
        Hugging Face processor used to resize/normalize each image.
    """

    def __init__(self, image_dir, processor):
        # Sort so the path order — and therefore the embedding order written
        # downstream — is deterministic; os.listdir order is arbitrary.
        self.image_paths = sorted(
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context manager closes the underlying file handle promptly (PIL
        # keeps it open lazily otherwise). convert('RGB') guards against
        # grayscale/RGBA/palette images entering the 3-channel pipeline.
        with Image.open(self.image_paths[idx]) as image:
            image = image.convert('RGB')
        return self.processor(images=image, return_tensors="pt")['pixel_values'][0]
def get_and_save_clip_embeddings(image_dir, output_file, batch_size=32, device='cuda'):
    """Embed every image in ``image_dir`` with CLIP and pickle the results.

    The pickle written to ``output_file`` is a dict with two parallel fields:
    ``'embeddings'`` — an (N, D) float array of CLIP image features — and
    ``'image_paths'`` — the N source paths, row-aligned with the embeddings.

    Parameters
    ----------
    image_dir : str
        Directory of images (see ``ImageDataset`` for accepted extensions).
    output_file : str
        Destination path for the pickled embeddings dict.
    batch_size : int, optional
        Images per forward pass (default 32).
    device : str, optional
        Torch device string; defaults to ``'cuda'`` — pass ``'cpu'`` on
        machines without a GPU.

    Raises
    ------
    ValueError
        If ``image_dir`` contains no images (``np.concatenate`` on an
        empty list), matching the original behavior.
    """
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    model.eval()  # inference mode: disable dropout/batch-norm updates
    with torch.no_grad():
        for batch in dataloader:
            features = model.get_image_features(pixel_values=batch.to(device))
            all_embeddings.append(features.cpu().numpy())

    # shuffle=False guarantees batches arrive in dataset order, so the full
    # path list lines up 1:1 with the concatenated embedding rows — no need
    # to rebuild it batch-by-batch from indices.
    with open(output_file, 'wb') as f:
        pickle.dump(
            {'embeddings': np.concatenate(all_embeddings),
             'image_paths': list(dataset.image_paths)},
            f,
        )
# image_dir = "dataset/"
# output_file = "image_embeddings.pkl"
# batch_size = 32
# device = "cuda" if torch.cuda.is_available() else "cpu"
# get_and_save_clip_embeddings(image_dir, output_file, batch_size, device)
# APP
# ---- APP: load the CLIP model and the precomputed image embeddings -------
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# NOTE(review): pickle is only safe on trusted data; this file is produced
# locally by get_and_save_clip_embeddings above, so that holds here.
with open('image_embeddings.pkl', 'rb') as f:
    _cache = pickle.load(f)  # distinct name — don't clobber the file handle
image_embeddings = _cache['embeddings']
image_names = _cache['image_paths']

# Presumably the dataset root the cached paths live under; unused below —
# kept for compatibility. TODO(review): confirm it can be removed.
image_paths = './dataset'
def cosine_similarity(a, b):
    """Pairwise cosine similarity between the rows of ``a`` and ``b``.

    For 2-D inputs of shapes (n, d) and (m, d), returns an (n, m) array
    whose (i, j) entry is the cosine of the angle between ``a[i]`` and
    ``b[j]``.
    """
    a_unit = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return a_unit @ b_unit.T
def find_similar_images(text):
    """Return the paths of the 4 images whose CLIP embeddings best match
    ``text``, most similar first.
    """
    tokens = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        query = model.get_text_features(**tokens).cpu().numpy()
    # Rank all cached image embeddings against the single text query.
    scores = cosine_similarity(query, image_embeddings)[0]
    ranked = np.argsort(scores)[::-1]
    return [image_names[i] for i in ranked[:4]]
# ---- Gradio UI: a text box in, a 4-image gallery out ----------------------
text_input = gr.Textbox(label="Input text", placeholder="Enter the images description")
imgs_output = gr.Gallery(label="Top 4 most similar images")

intf = gr.Interface(
    fn=find_similar_images,
    # Reuse the components defined above — the original built a second,
    # inconsistently-labelled pair here and left these two unused.
    inputs=text_input,
    outputs=imgs_output,
)

extra_text = gr.Markdown("""
The dataset contains images of dogs, rabbits, sharks, and deer.
Displaying the images might take a couple of seconds.
""")

# Compose the interface and the note into one page, then serve it.
with gr.Blocks() as app:
    intf.render()
    extra_text.render()

app.launch(share=True)