"""DINOv2-based image embedder (ViT-S/14 backbone loaded via torch.hub)."""
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from .abstract_embedder import AbstractImageEmbedder
class DinoV2Embedder(AbstractImageEmbedder):
    """Image embedder backed by Meta's DINOv2 ViT-S/14 model.

    Images are resized, center-cropped to 224x224, and normalized with
    ImageNet channel statistics before being passed through the frozen
    backbone in inference mode.
    """

    def __init__(self, device: str = "cpu"):
        """Load the DINOv2 ViT-S/14 backbone onto *device*.

        Args:
            device: torch device string, e.g. "cpu" or "cuda".
        """
        super().__init__(device)
        # NOTE(review): torch.hub.load downloads weights on first use —
        # requires network access or a warm local hub cache.
        self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(self.device)
        self.model.eval()  # freeze dropout/batchnorm behavior for inference
        self.transforms = T.Compose([
            T.Resize((256, 256), interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            # Standard ImageNet normalization statistics.
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def embed(self, image: Image.Image) -> np.ndarray:
        """Return the DINOv2 embedding of a single PIL image.

        Args:
            image: input image; converted to RGB internally, so any
                PIL mode is accepted.

        Returns:
            1-D numpy array holding the backbone's pooled embedding
            for the image (384 dims for ViT-S/14).
        """
        # Fix: annotate with Image.Image (the class) rather than the
        # PIL.Image module, which is not a valid type.
        rgb = image.convert("RGB")
        batch = self.transforms(rgb).unsqueeze(0).to(self.device)
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            return self.model(batch)[0].cpu().numpy()