import numpy as np
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
class EmbeddingModel:
    def __init__(self, model_path=None):
        if not model_path:
            model_path = "openai/clip-vit-base-patch32"
        # Pick the device first so the model can be moved onto it
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.model = CLIPModel.from_pretrained(model_path).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
    def get_text_templates(self, text_search):
        # Prompt templates (adapted from OpenAI's CLIP prompt engineering);
        # duplicates in the original list have been removed
        text_templates = [
            'a photo of a {}.',
            'a photo of the {}.',
            'a bad photo of a {}.',
            'a photo of many {}.',
            'a low resolution photo of the {}.',
            'a photo of my {}.',
            'a close-up photo of a {}.',
            'a cropped photo of a {}.',
            'a good photo of the {}.',
            'a photo of one {}.',
            'a close-up photo of the {}.',
            'the {} in a video game.',
            'an origami {}.',
            'a low resolution photo of a {}.',
            'a photo of a large {}.',
            'a blurry photo of a {}.',
            'a sketch of the {}.',
            'a pixelated photo of a {}.',
            'a good photo of a {}.',
            'a drawing of the {}.',
            'a photo of a small {}.',
        ]
        # Fill each template with the search term
        text_inputs = [template.format(text_search) for template in text_templates]
        return text_inputs
    def encode_text(self, text_input, apply_templates=True):
        # Optionally expand the query into multiple prompt templates
        if apply_templates:
            text_input = self.get_text_templates(text_input)
        with torch.no_grad():
            # Tokenize and encode the text with CLIP
            inputs = self.processor(text=text_input, return_tensors="pt", padding=True).to(self.device)
            text_embeddings = self.model.get_text_features(**inputs).detach().cpu().numpy()
        # L2-normalize so dot products equal cosine similarities
        text_embeddings /= np.linalg.norm(text_embeddings, axis=1, keepdims=True)
        return text_embeddings
    def encode_images(self, img_paths_list, normalize=True):
        if not isinstance(img_paths_list, list):
            img_paths_list = [img_paths_list]
        imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths_list]
        inputs = self.processor(images=imgs, return_tensors="pt")["pixel_values"].to(self.device)
        with torch.no_grad():
            img_embeddings = self.model.get_image_features(pixel_values=inputs)
        img_embeddings = img_embeddings.detach().cpu().numpy()
        # L2-normalize so dot products equal cosine similarities
        if normalize:
            img_embeddings /= np.linalg.norm(img_embeddings, axis=1, keepdims=True)
        return img_embeddings
    def get_similar_images_indexes(self, img_embeddings_np, text_search, n=5, apply_templates=True):
        # Encode the query text (one embedding per template if apply_templates)
        text_embeddings = self.encode_text(text_search, apply_templates=apply_templates)
        # Cosine similarity between every text prompt and every image
        similarity = np.dot(text_embeddings, img_embeddings_np.T)
        # Average over the prompt templates so all of them contribute,
        # not just the first one
        similarity = similarity.mean(axis=0)
        # Indexes of the n most similar images, best first
        results = (-similarity).argsort()
        return results[:n]
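

# Minimal usage sketch: embed a small gallery once, then rank it against a
# free-text query. The image paths below are hypothetical placeholders, not
# part of the original model code; replace them with your own files.
if __name__ == "__main__":
    model = EmbeddingModel()
    img_paths = ["photos/dog.jpg", "photos/cat.jpg", "photos/beach.jpg"]  # hypothetical
    img_embeddings = model.encode_images(img_paths)
    top_idxs = model.get_similar_images_indexes(img_embeddings, "a dog playing", n=2)
    print([img_paths[i] for i in top_idxs])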