import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer


class EmbeddingModel:
    """Wraps a CLIP model for text and image embedding and text-to-image search."""

    def __init__(self, model_path=None):
        if not model_path:
            model_path = "openai/clip-vit-base-patch32"

        self.model = CLIPModel.from_pretrained(model_path)
        self.processor = CLIPProcessor.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)

        # Use the GPU when available; the model must be moved there explicitly.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

    def get_text_templates(self, text_search):
        """Fill a set of CLIP prompt templates with the search term."""
        text_templates = [
            'a photo of a {}.',
            'a photo of the {}.',
            'a bad photo of a {}.',
            'a photo of many {}.',
            'a low resolution photo of the {}.',
            'a photo of my {}.',
            'a close-up photo of a {}.',
            'a cropped photo of a {}.',
            'a good photo of the {}.',
            'a photo of one {}.',
            'a close-up photo of the {}.',
            'the {} in a video game.',
            'an origami {}.',
            'a low resolution photo of a {}.',
            'a photo of a large {}.',
            'a blurry photo of a {}.',
            'a sketch of the {}.',
            'a pixelated photo of a {}.',
            'a good photo of a {}.',
            'a drawing of the {}.',
            'a photo of a small {}.',
        ]

        text_inputs = [template.format(text_search) for template in text_templates]
        return text_inputs

    def encode_text(self, text_input, apply_templates=True):
        """Embed a text query; optionally expand it over the prompt templates."""
        if apply_templates:
            text_input = self.get_text_templates(text_input)

        with torch.no_grad():
            inputs = self.tokenizer(text_input, return_tensors="pt", padding=True).to(self.device)
            text_embeddings = self.model.get_text_features(**inputs).cpu().numpy()

        # L2-normalize so that dot products behave as cosine similarities.
        text_embeddings /= np.linalg.norm(text_embeddings, axis=1, keepdims=True)

        return text_embeddings

    def encode_images(self, img_paths_list, normalize=True):
        """Embed a single image path or a list of image paths."""
        if not isinstance(img_paths_list, list):
            img_paths_list = [img_paths_list]

        # Convert to RGB so grayscale and RGBA files do not break the processor.
        imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths_list]
        inputs = self.processor(images=imgs, return_tensors="pt")["pixel_values"].to(self.device)

        with torch.no_grad():
            img_embeddings = self.model.get_image_features(pixel_values=inputs).cpu().numpy()

        if normalize:
            img_embeddings /= np.linalg.norm(img_embeddings, axis=1, keepdims=True)

        return img_embeddings

    def get_similar_images_indexes(self, img_embeddings_np, text_search, n=5, apply_templates=True):
        """Return the indexes of the n images most similar to the text query."""
        text_embeddings = self.encode_text(text_search, apply_templates=apply_templates)

        # Ensemble the prompt templates: average their embeddings and renormalize
        # instead of ranking by the first template only.
        text_embedding = text_embeddings.mean(axis=0, keepdims=True)
        text_embedding /= np.linalg.norm(text_embedding, axis=1, keepdims=True)

        similarity = np.dot(text_embedding, img_embeddings_np.T)[0]
        results = (-similarity).argsort()

        return results[:n]
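

# Minimal usage sketch: the image paths and the query term below are
# hypothetical placeholders; point them at your own files.
if __name__ == "__main__":
    model = EmbeddingModel()

    img_paths = ["photos/cat.jpg", "photos/dog.jpg", "photos/car.jpg"]  # hypothetical paths
    img_embeddings = model.encode_images(img_paths)

    # Rank the images against a text query and print the best matches.
    for idx in model.get_similar_images_indexes(img_embeddings, "cat", n=2):
        print(img_paths[idx])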