import numpy as np
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer


class EmbeddingModel:
    
    def __init__(self, model_path=None):
        # Fall back to the base CLIP checkpoint when no path is given
        if not model_path:
            model_path = "openai/clip-vit-base-patch32"

        # Use the GPU when one is available, otherwise the CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(self.device)

        # Load the model onto the selected device, plus its preprocessing helpers
        self.model = CLIPModel.from_pretrained(model_path).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)

    def get_text_templates(self, text_search):
        # Prompt templates; each one is filled with the search term
        text_templates = [
            'a photo of a {}.',
            'a photo of the {}.',
            'a bad photo of a {}.',
            'a photo of many {}.',
            'a low resolution photo of the {}.',
            'a photo of my {}.',
            'a close-up photo of a {}.',
            'a cropped photo of a {}.',
            'a good photo of the {}.',
            'a photo of one {}.',
            'a close-up photo of the {}.',
            'the {} in a video game.',
            'a origami {}.',
            'a low resolution photo of a {}.',
            'a photo of a large {}.',
            'a blurry photo of a {}.',
            'a sketch of the {}.',
            'a pixelated photo of a {}.',
            'a good photo of a {}.',
            'a drawing of the {}.',
            'a photo of a small {}.',
        ]

        # Fill every template with the search term
        text_inputs = [template.format(text_search) for template in text_templates]
        return text_inputs


    def encode_text(self, text_input, apply_templates=True):
        # If apply_templates is True, expand the query into one prompt per template
        if apply_templates:
            text_input = self.get_text_templates(text_input)

        with torch.no_grad():
            # Tokenize the prompts and move the tensors to the model's device
            inputs = self.processor(text=text_input, return_tensors="pt", padding=True).to(self.device)
            text_embeddings = self.model.get_text_features(**inputs).cpu().numpy()

        # L2-normalize each row (one row per prompt)
        text_embeddings /= np.linalg.norm(text_embeddings, axis=1, keepdims=True)

        return text_embeddings


    def encode_images(self, img_paths_list, normalize=True, nobg=False):
        # Accept a single path as well as a list of paths (nobg is currently unused)
        if not isinstance(img_paths_list, list):
            img_paths_list = [img_paths_list]

        # Load and preprocess the images, then move the batch to the model's device
        imgs = [Image.open(img_path) for img_path in img_paths_list]
        inputs = self.processor(images=imgs, return_tensors="pt")["pixel_values"].to(self.device)

        # Get image embeddings
        with torch.no_grad():
            img_embeddings = self.model.get_image_features(pixel_values=inputs)
            img_embeddings = img_embeddings.cpu().numpy()

        # Optionally L2-normalize each row (one row per image)
        if normalize:
            img_embeddings /= np.linalg.norm(img_embeddings, axis=1, keepdims=True)

        return img_embeddings


    def get_similar_images_indexes(self, img_embeddings_np, text_search, n=5, apply_templates=True):
        # Get text embeddings, one row per prompt
        text_embeddings = self.encode_text(text_search, apply_templates=apply_templates)

        # Dot product of text and image embeddings; this equals the cosine
        # similarity when the image embeddings are L2-normalized
        similarity = np.dot(text_embeddings, img_embeddings_np.T)

        # Average over the prompt rows so every template contributes to the score
        similarity = similarity.mean(axis=0)

        # Sort image indexes by descending similarity
        results = (-similarity).argsort()

        # Return indexes of the n most similar images
        return results[:n]
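

# Illustrative sketch (not part of EmbeddingModel): the ranking step above,
# written with plain Python lists instead of NumPy arrays so it runs without
# any dependencies. The helper names l2_normalize and rank_images, and the
# 3-dimensional vectors, are made up for this example; real CLIP embeddings
# come from encode_text / encode_images. It assumes scores are averaged over
# the prompt rows before sorting.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length so dot products are cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def rank_images(text_embeddings, image_embeddings, n=5):
    """Return the indexes of the n images with the highest mean similarity.

    text_embeddings: one unit-length vector per prompt.
    image_embeddings: one unit-length vector per image.
    """
    scores = []
    for img in image_embeddings:
        # Cosine similarity of this image against every prompt
        sims = [sum(t * i for t, i in zip(text, img)) for text in text_embeddings]
        # Average over prompts so each one contributes equally
        scores.append(sum(sims) / len(sims))

    # Image indexes sorted by descending score
    order = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)
    return order[:n]


# Made-up vectors standing in for two prompt embeddings and three image embeddings
text_embeddings = [l2_normalize(v) for v in ([1.0, 0.1, 0.0], [0.9, 0.2, 0.1])]
image_embeddings = [
    l2_normalize(v) for v in ([0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.5, 0.0])
]

top = rank_images(text_embeddings, image_embeddings, n=2)
print(top)
```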