Bastien Dechamps
salut
944c93a
raw
history blame
821 Bytes
import numpy as np
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel

from .abstract_embedder import AbstractImageEmbedder
class DinoEmbedder(AbstractImageEmbedder):
    """Image embedder backed by a pretrained ViT checkpoint (DINO by default)."""

    def __init__(self, device: str = "cpu", model_name: str = "facebook/dino-vitb8"):
        """Load the feature extractor and the model for ``model_name``.

        Args:
            device: Device string forwarded to the parent constructor; the
                model is moved to ``self.device`` (presumably set by the
                parent class from this argument -- confirm in
                ``AbstractImageEmbedder``).
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the feature extractor and the model.
        """
        super().__init__(device)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.model = ViTModel.from_pretrained(model_name).to(self.device)

    def embed(self, image: Image.Image) -> np.ndarray:
        """Return the model's last hidden state for ``image`` as a NumPy array.

        Args:
            image: Image handed to the feature extractor.

        Returns:
            ``outputs.last_hidden_state`` moved to the CPU and converted to a
            NumPy array. NOTE(review): expected to be shaped
            (batch, tokens, hidden) per the ViT model docs -- confirm.
        """
        features = self.feature_extractor(images=image, return_tensors="pt")
        # Move every preprocessed tensor to the device the model lives on.
        model_inputs = {
            name: tensor.to(self.device) for name, tensor in features.items()
        }

        # Inference only: without no_grad the output tensor requires grad and
        # Tensor.numpy() raises a RuntimeError; it also avoids building an
        # autograd graph that is never used.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # detach() is a cheap safeguard in case grad tracking is re-enabled
        # somewhere inside the forward pass.
        return outputs.last_hidden_state.detach().cpu().numpy()