File size: 821 Bytes
944c93a
 
4388025
 
 
 
 
 
 
 
 
 
 
 
944c93a
4388025
 
944c93a
4388025
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel

from .abstract_embedder import AbstractImageEmbedder


class DinoEmbedder(AbstractImageEmbedder):
    """Image embedder backed by a pretrained ViT checkpoint (DINO by default).

    The feature extractor and the model are both loaded from ``model_name``
    via ``from_pretrained``; the model is moved to ``self.device``.
    """

    def __init__(self, device: str = "cpu", model_name: str = "facebook/dino-vitb8"):
        """Load the preprocessing pipeline and the model weights.

        Args:
            device: Device string forwarded to the base class; the model is
                moved to ``self.device`` afterwards.
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the feature extractor and the model.
        """
        super().__init__(device)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.model = ViTModel.from_pretrained(model_name).to(self.device)

    def embed(self, image: Image.Image) -> np.ndarray:
        """Return the model's last hidden state for ``image`` as a NumPy array.

        Args:
            image: Image to embed; handed to the feature extractor as-is.

        Returns:
            ``outputs.last_hidden_state`` copied to the CPU and converted to a
            NumPy array. Per the Hugging Face ViT documentation this is shaped
            ``(batch, tokens, hidden)`` -- no pooling is applied here.
        """
        encoded = self.feature_extractor(images=image, return_tensors="pt")
        # Move every tensor produced by the extractor onto the model's device.
        model_inputs = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        # Inference only: without no_grad the output tensor requires grad, and
        # Tensor.numpy() refuses to convert such tensors (RuntimeError). It also
        # avoids building an autograd graph for every embedded image.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        return outputs.last_hidden_state.cpu().numpy()