Spaces:
Runtime error
Runtime error
import numpy as np
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel

from .abstract_embedder import AbstractImageEmbedder
class DinoEmbedder(AbstractImageEmbedder):
    """Image embedder backed by a Hugging Face ViT checkpoint (DINO by default).

    The preprocessor and the model are both loaded with ``from_pretrained``
    from ``model_name``; the model is moved to ``self.device``.
    """

    def __init__(self, device: str = "cpu", model_name: str = "facebook/dino-vitb8"):
        """Load the feature extractor and ViT backbone for ``model_name``.

        Args:
            device: Device string forwarded to ``AbstractImageEmbedder``. The
                model is moved to ``self.device``, which presumably the base
                class sets from this argument -- TODO confirm.
            model_name: Hugging Face Hub id or local path of a ViT checkpoint.
        """
        super().__init__(device)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.model = ViTModel.from_pretrained(model_name).to(self.device)

    def embed(self, image: Image.Image) -> np.ndarray:
        """Return the model's final-layer hidden states for ``image``.

        Args:
            image: Image handed to the feature extractor for preprocessing.

        Returns:
            ``last_hidden_state`` from the model output, copied to the CPU as a
            NumPy array. Presumably shaped (batch, tokens, hidden_size) -- TODO
            confirm against callers.
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        # Move every preprocessed tensor onto the same device as the model.
        inputs = {key: tensor.to(self.device) for key, tensor in inputs.items()}
        # Inference only: with autograd enabled the output would require grad
        # (model parameters do), and Tensor.numpy() raises on such tensors.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state.cpu().numpy()