import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer


class EmbeddingModel:
    def __init__(self, model_path=None):
        if not model_path:
            model_path = "openai/clip-vit-base-patch32"
        self.model = CLIPModel.from_pretrained(model_path)
        self.processor = CLIPProcessor.from_pretrained(model_path)
        # Loaded for convenience; not used directly by the methods below
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
        # Run on GPU when available and move the model there once at init
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        print(f"Using device: {self.device}")

    def get_text_templates(self, text_search):
        # Prompt-ensembling templates (a subset of the OpenAI CLIP ImageNet prompts,
        # with exact duplicates removed)
        text_templates = [
            'a photo of a {}.',
            'a photo of the {}.',
            'a bad photo of a {}.',
            'a photo of many {}.',
            'a low resolution photo of the {}.',
            'a photo of my {}.',
            'a close-up photo of a {}.',
            'a cropped photo of a {}.',
            'a good photo of the {}.',
            'a photo of one {}.',
            'a close-up photo of the {}.',
            'the {} in a video game.',
            'a origami {}.',
            'a low resolution photo of a {}.',
            'a photo of a large {}.',
            'a blurry photo of a {}.',
            'a sketch of the {}.',
            'a pixelated photo of a {}.',
            'a good photo of a {}.',
            'a drawing of the {}.',
            'a photo of a small {}.',
        ]
        # Fill each template with the search term
        return [template.format(text_search) for template in text_templates]

    def encode_text(self, text_input, apply_templates=True):
        # If apply_templates is True, expand the query into the prompt templates
        if apply_templates:
            text_input = self.get_text_templates(text_input)
        # Encode the text with CLIP (no gradients needed at inference)
        with torch.no_grad():
            inputs = self.processor(text=text_input, return_tensors="pt", padding=True).to(self.device)
            text_embeddings = self.model.get_text_features(**inputs).cpu().numpy()
        # L2-normalize so that dot products are cosine similarities
        text_embeddings /= np.linalg.norm(text_embeddings, axis=1, keepdims=True)
        return text_embeddings

    def encode_images(self, img_paths_list, normalize=True, nobg=False):
        # Note: nobg is currently unused
        if not isinstance(img_paths_list, list):
            img_paths_list = [img_paths_list]
        # Convert to RGB so grayscale/RGBA inputs don't break the processor
        imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths_list]
        pixel_values = self.processor(images=imgs, return_tensors="pt")["pixel_values"].to(self.device)
        # Encode the images with CLIP (no gradients needed at inference)
        with torch.no_grad():
            img_embeddings = self.model.get_image_features(pixel_values=pixel_values).cpu().numpy()
        # L2-normalize so that dot products are cosine similarities
        if normalize:
            img_embeddings /= np.linalg.norm(img_embeddings, axis=1, keepdims=True)
        return img_embeddings

    def get_similar_images_indexes(self, img_embeddings_np, text_search, n=5, apply_templates=True):
        # Encode the query (one embedding per template when apply_templates is True)
        text_embeddings = self.encode_text(text_search, apply_templates=apply_templates)
        # Cosine similarity between every prompt and every image, averaged over
        # the prompt templates (averaging similarities ranks identically to
        # averaging the normalized text embeddings)
        similarity = np.dot(text_embeddings, img_embeddings_np.T).mean(axis=0)
        # Return the indexes of the n most similar images, best first
        return (-similarity).argsort()[:n]
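

# --- Usage sketch ---
# A minimal example of indexing a folder of images once and then ranking them
# against a free-text query. The "images/..." paths and the query string are
# hypothetical placeholders; substitute your own files.
if __name__ == "__main__":
    embedder = EmbeddingModel()

    # Build the image index; rows are L2-normalized CLIP image embeddings
    image_paths = ["images/cat.jpg", "images/dog.jpg", "images/car.jpg"]
    img_embeddings = embedder.encode_images(image_paths)

    # Rank the indexed images against the query and print the top matches
    top_indexes = embedder.get_similar_images_indexes(img_embeddings, "dog", n=2)
    for idx in top_indexes:
        print(image_paths[idx])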