Spaces:
Runtime error
Runtime error
File size: 821 Bytes
944c93a 4388025 944c93a 4388025 944c93a 4388025 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import numpy as np
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel

from .abstract_embedder import AbstractImageEmbedder
class DinoEmbedder(AbstractImageEmbedder):
    """Image embedder backed by a DINO Vision Transformer checkpoint.

    Preprocessing is delegated to the checkpoint's ``ViTFeatureExtractor`` and
    the forward pass to ``ViTModel``; both are loaded from ``model_name``.
    """

    def __init__(self, device: str = "cpu", model_name: str = "facebook/dino-vitb8"):
        """Load the feature extractor and model, placing the model on ``device``.

        Args:
            device: Device string forwarded to the base class; the model and all
                input tensors are moved to ``self.device``.
            model_name: Hugging Face Hub id (or local path) of a ViT checkpoint.
        """
        super().__init__(device)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.model = ViTModel.from_pretrained(model_name).to(self.device)
        # This class only does inference; make that explicit so dropout and
        # similar training-time layers are disabled.
        self.model.eval()

    def embed(self, image: Image.Image) -> np.ndarray:
        """Return the model's final-layer hidden states for ``image``.

        Args:
            image: A PIL image (any input the feature extractor accepts works).

        Returns:
            ``last_hidden_state`` as a NumPy array on the CPU. Per the
            transformers ``ViTModel`` docs its shape is
            ``(batch_size, sequence_length, hidden_size)``.
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        # Without no_grad the output is attached to an autograd graph (the
        # model's parameters require grad), and ``Tensor.numpy()`` then raises
        # "Can't call numpy() on Tensor that requires grad".
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state.cpu().numpy()
|