from transformers import AutoProcessor, CLIPModel
import torch


class CLIPImageEncoder:
    """Wraps CLIP's vision tower to produce image embeddings as numpy arrays."""

    def __init__(self, device="cpu"):
        self.device = device
        # Load the pretrained CLIP model and its processor, move the model to
        # the requested device, and switch to inference mode.
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.model.eval()
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def encode_image(self, image_pil):
        """Encode a single PIL image into a 1-D embedding."""
        inputs = self.processor(images=image_pil, return_tensors="pt").to(self.device)
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        return image_features.detach().cpu().numpy()[0]

    def encode_images(self, batch):
        """Encode a batch of images passed as a dict with an "image" key,
        matching the dict-in/dict-out convention of datasets.Dataset.map(batched=True)."""
        images = batch["image"]
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        return {"clip_embeddings": image_features.detach().cpu().numpy()}